| <h1>The Ultra-Scale Playbook: Training LLMs on GPU Clusters</h1> | |
| <p>Fueled by the <a href="https://arxiv.org/abs/2001.08361">scaling laws</a>, the trend of training ever larger language models on vaster amounts of data has been driving progress in AI for the past couple of years. Initially, the development of the largest models happened exclusively behind the closed doors of a handful of research labs, but it has recently opened up with the release of models such as <a href="https://ai.meta.com/blog/meta-llama-3-1/">Llama 3.1 405B</a> and <a href="https://arxiv.org/abs/2501.12948">DeepSeek R1</a>. While these models have <a href="https://huggingface.co/meta-llama">openly shared</a> <a href="https://huggingface.co/deepseek-ai">weights</a> and their training recipes are described in <a href="https://ai.meta.com/research/publications/the-llama-3-herd-of-models/">technical</a> <a href="https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf">reports</a>, the challenging engineering involved in training at the necessary infrastructure scale is still hidden between the lines of a handful of papers and complex training frameworks. This ~~long blog post~~ open-source book is here to open this black box!</p> | |
| <p>In this book we invite you to follow us into the wonderful world of scaling the training of Large Language Models to tens, hundreds, and thousands of GPUs. It assumes you know the basics of LLM architecture and training, but are new to distributed training. This writing can be seen as the second part of a trilogy following our first blog post on processing data for pre-training, the so-called “<a href="https://huggingface.co/spaces/HuggingFaceFW/blogpost-fineweb-v1">FineWeb blog post</a>”. Having read both blog posts, you should have almost all the core knowledge needed to deeply understand how LLMs are built nowadays, missing only the final spices, such as data mixing or architecture choices, to complete the recipe (stay tuned…).</p> | |
| <p>Pre-training LLMs from scratch now requires an amount of compute that, in almost every case, exceeds what a single GPU or machine can provide. The clusters used to train these models range from hundreds to thousands of nodes, each usually equipped with 4 to 8 GPUs. To make the best use of such expensive hardware and to train in a reasonable time, a range of distributed training methods have been developed with the goal of keeping GPUs highly utilized at all times. Efficiently scaling LLM training is also no longer confined to pretraining, as fine-tuning larger models on more domain-specific data is becoming the standard practice to achieve the best results.</p> | |
| <p>In this post we’ll cover these scaling methods exhaustively while keeping a single story-line to understand where each technique comes from. We’ll cover data, tensor, pipeline and context parallelism as well as ZeRO and kernel fusion. The post is built on the following <strong>three foundations</strong>:</p> | |
| <p><strong>Quick intros on theory and concepts:</strong> before diving into code and experiments, we want to understand how each method works at a high level and what its advantages and limits are. You’ll learn which parts of a language model eat away your memory and when during training this happens. You’ll learn how we can solve memory constraints by parallelizing the models and increase the throughput by scaling up the number of GPUs.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image.png" /></p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%201.png" /></p> | |
| <p><strong>Clear code implementations:</strong> theory is one thing, but we discover all kinds of edge cases and important details when we implement something. That’s why we link to implementation references where possible. Depending on the case, we’ll use two code references: the <a href="https://github.com/huggingface/picotron">picotron</a> repository is built for education, so it usually implements each concept in a single, short, self-contained file. To look at production-ready code, on the other hand, we’ll refer to the <a href="https://github.com/huggingface/nanotron">nanotron</a> implementations; nanotron is a production training codebase used at Hugging Face.</p> | |
| <p><img alt="Picotron implements each key concept in a self-contained way, such that the method can be studied separately and in isolation." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%202.png" /></p> | |
| <p>Picotron implements each key concept in a self-contained way, such that the method can be studied separately and in isolation.</p> | |
| <p><strong>Real training efficiency benchmarks:</strong> Finally, how to <em>actually</em> scale your LLM training depends on your infrastructure, such as the kind of chips and the interconnect, so we can’t give a single unified recipe. What we will give, though, is a way to benchmark several setups, which is what we did on our cluster! We ran over 4100 distributed experiments with up to 512 GPUs to scan many possible distributed training layouts and model sizes. TODO: link to dataset too </p> | |
| <p><img alt="An overview of the over 4000 experiments across all Llama architectures where each data point corresponds to an experiment launch." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%203.png" /></p> | |
| <p>An overview of the over 4000 experiments across all Llama architectures where each data point corresponds to an experiment launch.</p> | |
| <p>As you can see, there’s a lot of ground to cover. Before getting into the trenches of distributed training, let’s take a quick high-level look at what we’ll cover in the post.</p> | |
| <h1>TL;DR</h1> | |
| <p>This book is very extensive, so we decided to start with a very general overview of how you can think about distributed training. At a high level, the key challenge in scaling LLM training is to make a training step (forward/backward/optimizer step) with a large batch size as fast as possible.</p> | |
| <p>When scaling up models and input batches, we quickly end up in situations where our target batch size won't fit in memory and/or the model itself is too large to fit in a single GPU's memory.</p> | |
| <p>To solve this scaling issue we’ll need to carefully evaluate different parallelization strategies and find the optimal balance between three main factors:</p> | |
| <ol> | |
| <li><strong>Memory Usage</strong><ul> | |
| <li>Hard limitation - if a training step doesn't fit in memory, training cannot proceed</li> | |
| <li>Sometimes we can increase compute (e.g. recomputation) or increase communication (e.g. ZeRO) to reduce memory usage</li> | |
| </ul> | |
| </li> | |
| <li><strong>Compute Efficiency</strong><ul> | |
| <li>Memory transfer can also decrease compute efficiency.</li> | |
| <li>We want our hardware to spend most time computing, so we need to reduce time spent on data transfers or unoptimized kernels.</li> | |
| <li>GPUs need sufficient workload (large enough matrices/batch sizes) to maintain high utilization (compute-bound) otherwise they become memory-bound (limited by memory bandwidth).</li> | |
| </ul> | |
| </li> | |
| <li><strong>Communication overhead</strong><ul> | |
| <li>Two main types for GPUs: intra-node (NVLink, TODO: bandwidth) and inter-node (network, TODO: bandwidth)</li> | |
| <li>Two main attributes: base latency and peak bandwidth. Base latency is a constant overhead that makes us want to do as few communications as possible, and peak bandwidth controls how fast we can move data between GPUs</li> | |
| <li>We prioritize using the fastest communication channels (like NVLink) for operations that occur frequently and/or block computation (e.g. tensor parallelism)</li> | |
| <li>We want to minimize communication overhead as it keeps GPUs idle, so we try to overlap communication with compute as much as possible</li> | |
| </ul> | |
| </li> | |
| </ol> | |
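| <p>To make the latency/bandwidth trade-off concrete, here is a toy cost model: a minimal sketch assuming each message costs a fixed base latency plus its size divided by the peak bandwidth. The latency and bandwidth values are hypothetical placeholders, not measurements of NVLink or of any network.</p> | |

```python
def transfer_time(num_bytes: float, latency_s: float, bandwidth_bytes_per_s: float) -> float:
    """Cost of one message: fixed base latency plus size divided by peak bandwidth."""
    return latency_s + num_bytes / bandwidth_bytes_per_s


def total_time(total_bytes: float, num_messages: int, latency_s: float, bandwidth_bytes_per_s: float) -> float:
    """Cost of sending total_bytes as num_messages equal, back-to-back messages."""
    per_message = total_bytes / num_messages
    return num_messages * transfer_time(per_message, latency_s, bandwidth_bytes_per_s)


payload = 1e9       # 1 GB to move (hypothetical)
latency = 10e-6     # assumed base latency: 10 microseconds per message
bandwidth = 100e9   # assumed peak bandwidth: 100 GB/s

for n in (1, 100, 10_000):
    print(f"{n:>6} messages: {total_time(payload, n, latency, bandwidth) * 1e3:.2f} ms")
```

| <p>The bytes and the bandwidth are the same in every case; only the number of messages changes, and once messages get small the latency term dominates.</p> | |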
| <p>But let’s not get too far ahead of ourselves; we’ll scale progressively. To guide you along the journey, and as a practical reference, we summarized the key concepts in a cheatsheet:</p> | |
| <p>[TODO: ADD CHEATSHEET]</p> | |
| <p>Now that we’ve nailed down a few key concepts and terms, let’s get started by revisiting the basic training steps of an LLM!</p> | |
| <h1>First Steps: Training on one GPU</h1> | |
| <p>Let’s start by quickly reviewing the very basics of model training before we start to scale to many GPUs. When a model is trained on a single GPU, the training typically consists of three steps: </p> | |
| <ol> | |
| <li>a forward pass which passes inputs through the model to yield its outputs,</li> | |
| <li>a backward pass to compute the gradients, and</li> | |
| <li>an optimization step using the gradients to update the parameters</li> | |
| </ol> | |
| <p>It looks generally like this: </p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%204.png" /></p> | |
| <blockquote> | |
| <p>Note: As we’ll see later, these steps may be repeated or intertwined but for now we’ll start simple. | |
| </p> | |
| </blockquote> | |
| <p>In this figure, the boxes on the top line can be seen as successive layers inside a model (same for the last line). The red boxes are the associated gradients for each of these layers, computed during the backward pass.</p> | |
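| <p>The three steps can be sketched in plain Python on a toy linear model. This is only an illustration, assuming a mean-squared-error loss and plain SGD with hand-written gradients; the function names are hypothetical, and in practice a framework computes the backward pass for you.</p> | |

```python
def forward(w, b, xs):
    """Step 1: forward pass, producing the model outputs y_hat = w * x + b."""
    return [w * x + b for x in xs]


def backward(xs, ys, preds):
    """Step 2: backward pass, gradients of the mean squared error w.r.t. w and b."""
    n = len(xs)
    grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    grad_b = sum(2 * (p - y) for p, y in zip(preds, ys)) / n
    return grad_w, grad_b


def optimizer_step(w, b, grad_w, grad_b, lr):
    """Step 3: plain SGD update of the parameters using the gradients."""
    return w - lr * grad_w, b - lr * grad_b


xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]  # generated by y = 2x + 1
w, b = 0.0, 0.0
for step in range(2000):
    preds = forward(w, b, xs)
    grad_w, grad_b = backward(xs, ys, preds)
    w, b = optimizer_step(w, b, grad_w, grad_b, lr=0.05)

print(f"w = {w:.3f}, b = {b:.3f}")  # approaches w = 2, b = 1
```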
| <p>The batch size ($bs$) is one of the important hyper-parameters for model training and affects both model convergence and throughput.</p> | |
| <p>If the batch size is too small, gradients tend to be noisy and the model may not be able to converge to optimal performance; on the other hand, a small batch size can be useful early in training to navigate the training landscape quickly. Conversely, a batch size that is too large makes less use of each training token, which slows convergence and wastes compute. You can find a nice discussion of this topic in <a href="https://arxiv.org/abs/1812.06162">OpenAI’s paper on large batch training</a> or Section 4.2 of the MiniMax-01 <a href="https://filecdn.minimax.chat/_Arxiv_MiniMax_01_Report.pdf">technical report</a>.</p> | |
| <blockquote> | |
| <p>Note: For instance, during DeepSeek-V3/R1 training “the batch size is gradually increased from 3072 to 15360 in the training of the first 469B tokens, and then keeps 15360 in the remaining training” | |
| </p> | |
| </blockquote> | |
| <p>Batch size also affects the time it takes to train on a given text dataset: a small batch size requires more optimizer steps to train on the same number of samples. Optimizer steps are costly (in compute time), so the total training time increases compared to a larger batch size. That being said, the batch size can often be varied quite widely around the optimal value without major impact on the performance of the model, i.e. the final model performance is usually not very sensitive to the exact batch size around the optimum.</p> | |
| <p>In the LLM pretraining community, batch sizes are commonly reported in tokens rather than in number of samples ($bst$ = Batch Size Tokens), which makes training numbers generally independent of the exact input sequence length used during training.</p> | |
| <p>In the simplest case, training on a single machine, $bs$ (in samples) and $bst$ are related through the model input sequence length $seq$ as follows:</p> | |
| <p>$$ | |
| bst = bs \times seq | |
| $$</p> | |
| <blockquote> | |
| <p>Note: From here onward we’ll show the formulas for the batch size in terms of samples but you can always get its token-unit counterpart by multiplying it with the sequence length. | |
| </p> | |
| </blockquote> | |
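| <p>As a quick sanity check of the conversion, here is a small sketch; the 4096-token sequence length and the 1T-token budget are assumptions chosen for illustration. It also shows the effect discussed above: a larger batch means fewer optimizer steps for the same number of tokens.</p> | |

```python
def tokens_per_batch(bs: int, seq: int) -> int:
    """bst = bs * seq"""
    return bs * seq


def samples_per_batch(bst: int, seq: int) -> int:
    """Sequences needed to reach a batch size of bst tokens (rounded up)."""
    return -(-bst // seq)  # ceiling division on integers


def optimizer_steps(total_tokens: int, bst: int) -> int:
    """Optimizer steps needed to go through total_tokens once (rounded up)."""
    return -(-total_tokens // bst)


seq = 4096                 # assumed sequence length
target_bst = 4 * 2**20     # ~4M tokens per batch
bs = samples_per_batch(target_bst, seq)
print(bs, tokens_per_batch(bs, seq))

budget = 10**12            # assumed 1T-token training budget
print(optimizer_steps(budget, target_bst), optimizer_steps(budget, 2 * target_bst))
```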
| <p>A sweet spot for recent LLM training is typically on the order of 4-60 million tokens per batch. However, a typical issue when scaling the training of our model to these large batch sizes is running out of memory, i.e. our GPU doesn’t have enough memory.</p> | |
| <blockquote> | |
| <p>Note: Llama 1 was trained with a batch size of ~4M tokens for 1.4 trillion tokens, while DeepSeek was trained with a batch size of ~60M tokens for 14 trillion tokens. | |
| </p> | |
| </blockquote> | |
| <p><strong>It’s time to tackle our first scaling problem: what if our model starts exploding GPU memory before we’ve reached our target batch size (maybe in some cases even when using the lowest possible batch size, <code>bs=1</code>)?</strong></p> | |
| <p>Let’s start by quickly understanding what led to our out-of-memory issue in the first place. This will help us gain some useful intuitions for later.</p> | |
| <h2>Memory usage in Transformers</h2> | |
| <p>When training a neural network model, we store several items in memory:</p> | |
| <ul> | |
| <li>Model weights</li> | |
| <li>Activations needed to compute the gradients</li> | |
| <li>Model gradients</li> | |
| <li>Optimizer states</li> | |
| </ul> | |
| <blockquote> | |
| <p>You would think that we could compute the memory requirements of a model exactly, but there are a few additional memory occupants that make it hard to be exact:</p> | |
| <ul> | |
| <li>CUDA kernels typically require 1-2 GB of GPU memory, which you can quickly verify by running <code>import torch; torch.ones((1, 1)).to("cuda")</code> and then checking the GPU memory with <code>nvidia-smi</code>.</li> | |
| <li>Some residual memory is used by buffers and intermediate results, and some memory can’t be used due to fragmentation.</li> | |
| </ul> | |
| <p>We’ll neglect these last two contributors as they are typically small and constant factors.</p> | |
| </blockquote> | |
| <p>These items are stored as tensors which come in different <em>shapes</em> and <em>precisions</em>. The <em>shapes</em> are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding as we’ll see later. <em>Precision</em> refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor.</p> | |
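| <p>A tensor’s memory is simply its number of elements (the product of its shape) times the bytes per element of its precision. A minimal sketch, using a hypothetical hidden-state tensor of shape (bs, seq, h) = (1, 4096, 4096):</p> | |

```python
import math

# Bytes per element for the precisions discussed above.
BYTES_PER_ELEMENT = {"fp32": 4, "bf16": 2, "fp8": 1}


def tensor_bytes(shape, dtype: str) -> int:
    """Memory of a dense tensor: number of elements times bytes per element."""
    return math.prod(shape) * BYTES_PER_ELEMENT[dtype]


shape = (1, 4096, 4096)  # hypothetical (bs, seq, h) hidden-state tensor
for dtype in BYTES_PER_ELEMENT:
    print(f"{dtype}: {tensor_bytes(shape, dtype) / 2**20:.0f} MiB")
```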
| <p>So how can we quickly determine memory usage from these variables? One simple way is to do this empirically and just measure it.</p> | |
| <h3>Memory profiling a training step</h3> | |
| <p>Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:</p> | |
| <p><img alt="llama-1b-memory.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/44e32d18-4ed6-455a-a1f7-bbbdebe2fefd.png" /></p> | |
| <p>Clearly the first step looks very different from the subsequent ones, but let’s first have a look at the general anatomy of a step: first the activations increase quickly as we do the forward pass, then during the backward pass the gradients build up and as the backward pass propagates, the stored activations used to compute the gradients are progressively cleared. Finally, we perform the optimization step during which we need all the gradients and then update the optimizer states before we start the next forward pass. </p> | |
| <p>Why does the first step look different? The activations increase quickly and then plateau for a while. In this first step, the PyTorch caching allocator does a lot of preparation work, setting up memory allocations to speed up the subsequent steps so that they don’t need to search for free memory blocks afterwards (see <a href="https://zdevito.github.io/2022/08/04/cuda-caching-allocator.html">Zach’s blog</a>). After the first step, we also see the optimizer states appear, which raise the baseline memory usage for the following training steps. </p> | |
| <blockquote> | |
| <p>Ever noticed how sometimes the training succeeds in the first step but then OOMs during the following training steps? This can be explained by the build-up of the optimizer state after the first step. | |
| </p> | |
| </blockquote> | |
| <p>Now that we have a first view of memory, let’s see how scaling up training is often a question of maximizing compute efficiency while keeping the memory requirements of these various items (activations, parameters, gradients, optimizer states) within the memory constraints of the GPUs.</p> | |
| <h3>Weights/grads/optimizer states memory</h3> | |
| <p>We can actually pretty easily estimate the memory needed for the model’s weights, gradients and optimizer states.</p> | |
| <p>For a simple transformer LLM the number of parameters is given by the <a href="https://michaelwornow.net/2024/01/18/counting-params-in-transformer">following formula</a>:</p> | |
| <p>$$ | |
| N = h \times v + L \times (12 \times h^2 + 13 \times h) + 2 \times h | |
| $$</p> | |
| <blockquote> | |
| <p>Note: we excluded the positional embedding count as rotary embeddings are not learned. | |
| </p> | |
| </blockquote> | |
| <p>In that equation, $h$ is the hidden dimension, $v$ the vocabulary size, and $L$ the number of layers in the model. Note that looking at the equation we can see that the term that will dominate at large hidden dimensions is the $h^2$ term since it’s the only one growing quadratically as we scale the parameters.</p> | |
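| <p>The formula translates directly into code. The configuration below (h = 4096, v = 32000, L = 32) is hypothetical and not meant to match any specific released model; the printout also shows how much of the total comes from the $h^2$ terms.</p> | |

```python
def num_params(h: int, v: int, num_layers: int) -> int:
    """N = h*v + L*(12*h^2 + 13*h) + 2*h for a simple transformer LLM
    (positional embeddings excluded, as rotary embeddings are not learned)."""
    embeddings = h * v                 # token embedding matrix
    per_layer = 12 * h**2 + 13 * h     # weight matrices (h^2 terms) plus biases/norms (h terms)
    final_norm = 2 * h                 # final LayerNorm weight and bias
    return embeddings + num_layers * per_layer + final_norm


h, v, num_layers = 4096, 32_000, 32    # hypothetical configuration
n = num_params(h, v, num_layers)
quadratic_share = num_layers * 12 * h**2 / n
print(f"N = {n:,} (~{n / 1e9:.2f}B), h^2 terms: {quadratic_share:.1%} of N")
```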
| <p>Memory requirements for the parameters and gradients are simply the number of parameters multiplied by the number of bytes per parameter. In good old-fashioned full precision (FP32) training, both parameters and gradients require 4 bytes, while the optimizer, if we use Adam, requires the momentum and variance to be stored, which adds another 2 × 4 bytes per parameter. In summary:</p> | |
| <p>$$ | |
| \begin{aligned} | |
| m_{params} &= 4 \times N \\ | |
| m_{grad} &= 4 \times N \\ | |
| m_{opt} &= (4+4) \times N | |
| \end{aligned} | |
| $$</p> | |
| <p>Now let’s have a look at how things change if we train with <a href="https://arxiv.org/abs/1710.03740">mixed precision</a>. The default nowadays for mixed precision training is BF16, which requires 2 bytes per parameter and per gradient, as well as an additional copy of the model weights in FP32 (4 bytes per parameter). In addition to the parameters and gradients, we need to store the optimizer states: for the Adam optimizer, this requires the momentum and the variance, usually stored in FP32 for numerical stability, each using 4 bytes. The FP32 weights and the two Adam states thus add up to 12 bytes per parameter, on top of the 4 bytes for the BF16 parameters and gradients. </p> | |
| <blockquote> | |
| <p>Note: See some more details below when we cover the <a href="https://arxiv.org/pdf/1910.02054">ZeRO</a> work. | |
| </p> | |
| </blockquote> | |
| <p>Here’s the summary:</p> | |
| <p>$$ | |
| \begin{aligned} | |
| m_{params} &= 2 \times N \\ | |
| m_{grad} &= 2 \times N \\ | |
| m_{params\_fp32} &= 4 \times N \\ | |
| m_{opt} &= (4+4) \times N | |
| \end{aligned} | |
| $$</p> | |
| <blockquote> | |
| <p>Some libraries store gradients in FP32, which requires an additional $m_{grad\_fp32} = 4 \times N$ of memory. This is done, for example, in nanotron, because <code>bf16</code> is lossy for smaller values and we always prioritize stability. See <a href="https://github.com/microsoft/DeepSpeed/issues/1773">this DeepSpeed issue</a> for more information. | |
| </p> | |
| </blockquote> | |
| <p>Interestingly, mixed precision itself doesn’t save overall memory: it just distributes the memory differently across the three components, and in fact adds another 4 bytes per parameter over full precision training if we accumulate gradients in FP32. It’s still advantageous because running the forward and backward passes in half precision allows us to (1) use optimized lower-precision operations on the GPU, which are faster, and (2) reduce the activation memory requirements during the forward pass.</p> | |
| <p>Let’s get a sense of how much memory we generally need for a model (full and mixed precision give the same overall value):</p> | |
| <table> | |
| <thead><tr><th>Model parameters</th><th>FP32, or BF16 without FP32 grad accumulation</th><th>BF16 with FP32 grad accumulation</th></tr></thead> | |
| <tbody> | |
| <tr><td>1B</td><td>16 GB</td><td>20 GB</td></tr> | |
| <tr><td>7B</td><td>112 GB</td><td>140 GB</td></tr> | |
| <tr><td>70B</td><td>1120 GB</td><td>1400 GB</td></tr> | |
| <tr><td>405B</td><td>6480 GB</td><td>8100 GB</td></tr> | |
| </tbody> | |
| </table> | |
| <blockquote> | |
| <p>Using FP8 training instead of BF16 would further decrease the memory usage but it is less stable and a very active research topic (see [<a href="https://x.com/xariusrke/status/1826669126955278401">ref</a>]) and we’ll cover it in more detail later. | |
| </p> | |
| </blockquote> | |
| <p>As we can see, as soon as we reach <strong>7B</strong> (!), the weight and optimizer requirements already start to add up significantly and exceed the memory of a typical GPU, e.g. 80GB for an H100 GPU.</p> | |
| <p>But for now, let’s stick with models that still fit in a single GPU and take a look at the other big contributor to our memory budget: the activation memory.</p> | |
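| <p>The table can be reproduced with a few lines of arithmetic: a sketch assuming the byte counts given above (activations excluded, 1 GB taken as 10<sup>9</sup> bytes) and an 80 GB GPU as the reference point.</p> | |

```python
# Bytes per parameter for weights, gradients and Adam states (activations excluded).
BYTES_PER_PARAM = {
    # FP32: params (4) + grads (4) + Adam momentum and variance (4 + 4)
    "fp32": 4 + 4 + (4 + 4),
    # BF16 mixed precision: params (2) + grads (2) + FP32 weight copy (4) + Adam (4 + 4)
    "bf16": 2 + 2 + 4 + (4 + 4),
    # Same as bf16, plus an FP32 copy of the gradients (4)
    "bf16_fp32_grads": 2 + 2 + 4 + 4 + (4 + 4),
}
GPU_MEMORY_GB = 80  # reference point: an 80 GB GPU such as an H100


def training_state_gb(num_params: float, recipe: str) -> float:
    """Weights + gradients + optimizer states in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[recipe] / 1e9


for billions in (1, 7, 70, 405):
    mem = training_state_gb(billions * 1e9, "bf16")
    mem_fp32_grads = training_state_gb(billions * 1e9, "bf16_fp32_grads")
    status = "fits in" if mem <= GPU_MEMORY_GB else "exceeds"
    print(f"{billions}B: {mem:.0f} GB / {mem_fp32_grads:.0f} GB ({status} {GPU_MEMORY_GB} GB)")
```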
| <h3>Activations memory</h3> | |
| <p>Activation memory is a bit more complex to compute than the weights, gradients and optimizer states, in part because it depends on the inputs of the model. If you’re unsure why we even need to store activations for the backward pass, <a href="https://www.determined.ai/blog/act-mem-2">this reference</a> is a good quick refresher. After a careful inspection of how the backward pass is computed, we can estimate the total memory required for the activations in mixed precision and arrive at the following equation:</p> | |
| <p>$$ | |
| m_{act} = L \times seq \times bs \times h \times \left(34 + \frac{5 \times n_{heads} \times seq}{h}\right) | |
| $$</p> | |
| <p>Here $L$ is the number of layers, $seq$ the sequence length, $bs$ the batch size in samples, $h$ the hidden dimension of the model and $n_{heads}$ the number of heads.</p> | |
| <p>For the exact derivation of the numbers, you can follow this <a href="https://arxiv.org/pdf/2205.05198">NVIDIA paper</a> on recomputation; it essentially requires you to do some accounting of the sizes of all intermediate activations between each operation.</p> | |
| <p>An interesting observation here is that the memory is not static for a given model: it scales linearly with the batch size, and with the sequence length both linearly and quadratically (the $5 \times n_{heads} \times seq / h$ term is multiplied by $seq$ again). This means the activation memory is the part that will blow up when we increase our batch size or train with longer sequences. We can use this equation to look at how memory usage changes for various sequence lengths, for example for Llama models (<code>bs=1</code>):</p> | |
| <p><img alt="llama-memory-bars-no-recomp.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/llama-memory-bars-no-recomp.png" /></p> | |
| <p>This graph tells a striking story: for short sequences (or similarly for small batch sizes), activations are almost negligible, but starting at around 2-4k tokens they take up a significant amount of memory, while parameter, gradient and optimizer state usage (discussed above) stays independent of the sequence length and batch size.</p> | |
| <p><strong>For large numbers of input tokens (i.e. large batch sizes/sequences), activations become by far the largest memory burden.</strong> </p> | |
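| <p>Plugging numbers into the activation formula shows this effect directly. A sketch, assuming a hypothetical model with $L = 32$, $h = 4096$ and $n_{heads} = 32$ at <code>bs=1</code>; the result is in bytes, as in the formula above.</p> | |

```python
def activation_bytes(num_layers, seq, bs, h, n_heads):
    """m_act = L * seq * bs * h * (34 + 5 * n_heads * seq / h), in bytes."""
    return num_layers * seq * bs * h * (34 + 5 * n_heads * seq / h)


L, h, n_heads = 32, 4096, 32  # hypothetical model shape
for seq in (1024, 2048, 4096, 8192):
    gb = activation_bytes(L, seq, 1, h, n_heads) / 1e9
    print(f"seq={seq:>5}, bs=1: {gb:7.1f} GB")
```

| <p>Doubling <code>bs</code> exactly doubles every value, while doubling <code>seq</code> more than doubles it because of the attention term.</p> | |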
| <p>Is there a way to tame this “activation explosion”? Good question, reader!</p> | |
| <p>It’s time to explain our first technique, called <strong><em>activation recomputation</em></strong>, which will help us cap the activation memory footprint. It is an essential tool in today’s large model training toolbox.</p> | |
| <h2><strong>Activation recomputation</strong></h2> | |
| <p>The general idea behind <strong><em>activation recomputation</em></strong> (also called <em>gradient checkpointing</em> or <em>rematerialization</em>) is to discard some activations during the forward pass to save memory and spend some extra compute to recompute them on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. FF, LayerNorm etc.), so that we can use them during the backward pass to compute gradients. When we use recomputation, we typically store activations only at a few key points along the model architecture, discard the rest, and recompute them on the fly during the backward pass from the nearest saved activations, essentially re-running a sub-part of the forward pass to trade memory for compute. It generally looks like this:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%205.png" /></p> | |
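| <p>The mechanism can be illustrated without any framework. The sketch below is a toy, assuming a chain of scalar “layers” $y = \sin(w \cdot x)$ whose backward only needs the layer input: it either keeps every activation, or keeps only every k-th one as a checkpoint and re-runs the forward pass of each segment during the backward pass. All names are hypothetical; in PyTorch, the corresponding utility is <code>torch.utils.checkpoint</code>.</p> | |

```python
import math


def layer(w, x):
    """Toy 'layer': y = sin(w * x)."""
    return math.sin(w * x)


def layer_backward(w, x, grad_out):
    """Backward of y = sin(w * x); it only needs the layer input x. Returns (dL/dw, dL/dx)."""
    local = grad_out * math.cos(w * x)
    return local * x, local * w


def grads_store_all(ws, x0):
    """No recomputation: keep every activation from the forward pass."""
    acts = [x0]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    grads, grad_x = [0.0] * len(ws), 1.0  # loss = final output, so dL/dx_L = 1
    for i in reversed(range(len(ws))):
        grads[i], grad_x = layer_backward(ws[i], acts[i], grad_x)
    return grads, len(acts)


def grads_checkpointed(ws, x0, every):
    """Keep only every `every`-th activation; recompute the others during backward."""
    checkpoints, x = {0: x0}, x0
    for i, w in enumerate(ws):
        x = layer(w, x)  # forward pass: intermediate results are discarded...
        if (i + 1) % every == 0 and i + 1 < len(ws):
            checkpoints[i + 1] = x  # ...except at segment boundaries
    grads, grad_x, recomputed = [0.0] * len(ws), 1.0, 0
    for start in sorted(checkpoints, reverse=True):  # walk segments last to first
        end = min(start + every, len(ws))
        seg = [checkpoints[start]]  # re-run this segment's forward from its checkpoint
        for i in range(start, end - 1):
            seg.append(layer(ws[i], seg[-1]))
            recomputed += 1
        for i in reversed(range(start, end)):
            grads[i], grad_x = layer_backward(ws[i], seg[i - start], grad_x)
    return grads, len(checkpoints), recomputed


ws = [0.5 + 0.1 * i for i in range(12)]  # 12 toy layers
full, stored_full = grads_store_all(ws, 0.7)
ckpt, stored_ckpt, recomputed = grads_checkpointed(ws, 0.7, every=4)
print(all(math.isclose(a, b) for a, b in zip(full, ckpt)))  # same gradients
print(stored_full, stored_ckpt, recomputed)  # activations kept vs. layers recomputed
```

| <p>Both versions produce the same gradients; the checkpointed one keeps 3 activations instead of 13 and pays for it with 9 extra layer evaluations in the backward pass.</p> | |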
| <p>There are several strategies to select key activations to store:</p> | |
| <ul> | |
| <li><strong>Full</strong>: We checkpoint activations at the transition point between each layer of the Transformer model. This is usually called the <code>full</code> strategy since it requires a forward pass through each layer essentially adding a full forward pass during the backward pass. This strategy saves the most memory but is the most expensive one in terms of compute. It generally increases the compute cost and time by up to 30-40% which is very noticeable.</li> | |
| <li><strong>Selective</strong>: In general, we can do better than full. The authors of the <a href="https://arxiv.org/pdf/2205.05198">recomputation paper</a> did a detailed analysis studying which activations grow the largest and have the cheapest recomputation cost in terms of FLOPs. It turns out that the attention computations fall into that category, so we can usually discard them and focus on checkpointing the expensive feedforward computations. For a GPT-3 (175B) model this means <strong>70% activation memory reduction at a 2.7% compute cost</strong>.</li> | |
| </ul> | |
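| <p>The selective recomputation figure can be roughly checked with the activation formula from the previous section, assuming (as in the recomputation paper) that selective recomputation drops exactly the attention term $5 \times n_{heads} \times seq / h$ and keeps the $34$ term. With the GPT-3 175B shape (h = 12288, 96 heads, seq = 2048), the attention term equals 80:</p> | |

```python
def act_bytes_per_layer(seq, bs, h, n_heads, selective=False):
    """Per-layer activation memory: seq * bs * h * (34 + 5 * n_heads * seq / h).
    Assumption: selective recomputation removes the attention term entirely."""
    attention_term = 0.0 if selective else 5 * n_heads * seq / h
    return seq * bs * h * (34 + attention_term)


# GPT-3 175B shape: h = 12288, 96 attention heads, sequence length 2048
full = act_bytes_per_layer(seq=2048, bs=1, h=12288, n_heads=96)
selective = act_bytes_per_layer(seq=2048, bs=1, h=12288, n_heads=96, selective=True)
print(f"activation memory reduction: {1 - selective / full:.1%}")
```

| <p>This back-of-the-envelope ratio (80/114) lands close to the 70% quoted above. The number of layers cancels out in the ratio, so it is left out.</p> | |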
| <blockquote> | |
| <p>Note: In recent models like DeepSeek V3, selective checkpointing is performed, and the stored attention activations are made even smaller by using so-called “Multi-Head Latent Attention” (MLA) to optimize activation memory usage. | |
| </p> | |
| </blockquote> | |
| <p>Let’s see how drastically recomputation strategies can in practice reduce the memory footprint and how selective recomputation strikes a nice balance between memory saving and recomputation cost:</p> | |
| <p><img alt="llama-8b-memory-bars--recomp.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/llama-8b-memory-bars--recomp.png" /></p> | |
| <blockquote> | |
| <p>When you’re measuring how efficient your training setup is at using the accelerator’s available compute, you may want to take recomputation into account when measuring the total FLOPS (Floating point operations per second) of your training setup and comparing it to theoretical maximum FLOPS of your GPU/TPU/accelerator to estimate GPU utilization. Taking recomputation into account when calculating FLOPS for a training step gives a value called “hardware FLOPS” which is the real number of operations performed on the accelerator. Dividing this number by the duration of one training step and the maximum accelerator FLOPS yields the <em>Hardware FLOPS Utilization (HFU).</em></p> | |
| </blockquote> | |
| <p>However, when comparing various accelerators, what really matters at the end of the day is the end-to-end time needed to train the same model on the same dataset, i.e. if an accelerator allows us to skip recomputation, and thus perform fewer operations per second while training faster, it should be rewarded. An alternative is therefore to compute what is called <em>Model FLOPS Utilization (MFU)</em>, which, in contrast to HFU, only accounts for the operations required to compute the forward and backward passes, not recomputation, i.e. it is specific to the model, not to the training implementation.</p> | |
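| <p>A sketch of both metrics, assuming the common approximation of about 6N FLOPs per token for a forward+backward pass (2N forward, 4N backward, ignoring attention FLOPs), plus one extra forward pass (2N) per token with full recomputation. The model size, step time and peak FLOPS below are hypothetical.</p> | |

```python
def flops_utilization(flops_per_step: float, step_time_s: float, peak_flops: float) -> float:
    """Fraction of the accelerator's peak throughput actually used during a step."""
    return flops_per_step / (step_time_s * peak_flops)


n_params = 8e9           # hypothetical model size
tokens_per_step = 1e6    # hypothetical batch size in tokens
step_time_s = 100.0      # hypothetical measured step time
peak_flops = 989e12      # assumed peak dense BF16 FLOPS of one accelerator

model_flops = 6 * n_params * tokens_per_step     # forward (2N) + backward (4N) per token
hardware_flops = 8 * n_params * tokens_per_step  # + one recomputed forward (2N) per token

mfu = flops_utilization(model_flops, step_time_s, peak_flops)
hfu = flops_utilization(hardware_flops, step_time_s, peak_flops)
print(f"MFU = {mfu:.1%}, HFU = {hfu:.1%}")
```

| <p>With the same step time, HFU is always higher than MFU when recomputation is on, which is exactly why MFU is the fairer number for comparing setups.</p> | |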
| <p>Most training frameworks these days use FlashAttention (which we’ll cover a bit later), which natively integrates activation recomputation into its optimization strategy by recomputing attention scores and matrices in the backward pass instead of storing them. Thus most people using FlashAttention are already making use of selective recomputation.</p> | |
| <p><strong>As you’ve now understood, activation recomputation increases the number of FLOPs slightly due to recomputation, while it significantly reduces memory access overhead.</strong> </p> | |
| <p>This trade-off is particularly advantageous on hardware with small high-speed memory, like GPUs, as accessing memory is typically slower than performing computations. Despite the additional operations involved, the overall effect is thus often faster computation as well, in addition to the much lower memory footprint.</p> | |
| <p>Now that we’ve learned about recomputation, we can tame the activation memory usage, as the graphs above show!</p> | |
| <p>However, activations still have a linear dependence on the batch size, and all our profiles in the bar plots above used <code>bs=1</code>, so as we move to larger batch sizes, activations might become an issue again. Do not despair, as we have a second tool in our box: <strong><em>gradient accumulation</em></strong> to the rescue!</p> | |
| <h2>Gradient accumulation</h2> | |
| <p>Now that we’ve used activation recomputation to fit our model with a small batch size on a single GPU, we still need to reach our target batch size, let’s say 1M tokens (see our earlier discussion on optimal batch size). Gradient accumulation is a very straightforward method to do this without a memory explosion.</p> | |
| <p>With <em>gradient accumulation</em> we split our batch into micro-batches, do forward and backward passes repeatedly on each micro-batch, compute the gradients, and, as the name suggests, sum the gradients for each micro-batch before doing a final optimizer step. In practice, we perform the optimization step not on the sum but on the average of the gradients, so the result is independent of the number of gradient accumulation steps.</p> | |
| <p>Let’s call the batch size for each forward pass the <code>micro batch size</code> (MBS). We’ll refer to the overall batch size between each optimizer step as the <code>global batch size</code> (GBS). If we do one optimizer step for each 8 forward/backward passes, the <code>global batch size</code> will be 8 times the <code>micro batch size</code>.</p> | |
| <p>What we now call <code>global batch size</code> thus corresponds to what we’ve called up to now just <code>batch size</code> for simplicity (we now make our terms more precise to avoid ambiguity).</p> | |
| <p>With gradient accumulation the global batch size can be simply computed as follows:</p> | |
| <p>$$ | |
| bs = gbs = mbs \times grad\_acc | |
| $$</p> | |
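| <p>The equivalence between accumulating and taking one big step is easy to check numerically. A minimal sketch, assuming a toy one-parameter linear model with a mean-squared-error loss and equally sized micro-batches (the case where the average of micro-batch gradients equals the full-batch gradient):</p> | |

```python
def grad_mse(w, xs, ys):
    """dL/dw for L = mean((w * x - y)^2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, mbs):
    """Gradient accumulation: one forward/backward per micro-batch, then average."""
    if len(xs) % mbs != 0:
        raise ValueError("this sketch assumes gbs is a multiple of mbs")
    grad_acc = len(xs) // mbs  # number of accumulation steps, gbs = mbs * grad_acc
    buffer = 0.0               # gradient buffer kept across micro-batches
    for k in range(grad_acc):
        micro = slice(k * mbs, (k + 1) * mbs)
        buffer += grad_mse(w, xs[micro], ys[micro])
    return buffer / grad_acc


xs = [float(i) for i in range(8)]   # gbs = 8 samples
ys = [3.0 * x + 0.5 for x in xs]
w = 1.0
print(grad_mse(w, xs, ys))                  # full batch in one pass
print(accumulated_grad(w, xs, ys, mbs=2))   # 4 micro-batches of 2
```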
| <p>Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction. One drawback, however, is that gradient accumulation requires multiple consecutive forward/backward passes per optimization step, thereby increasing the compute overhead and slowing down training. No free lunch! </p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%206.png" /></p> | |
| <p><strong>Gradient accumulation allows us to reduce the memory of activations, which grows linearly with batch size, by processing only partial micro-batches at a time. There is a small overhead caused by the additional forward and backward passes.</strong></p> | |
| <blockquote> | |
| <p>Note: Using gradient accumulation means we need to keep buffers where we accumulate gradients, which persist throughout a training step. Without gradient accumulation, gradients are computed in the backward pass while the activation memory is freed, which means a lower peak memory. | |
| </p> | |
| </blockquote> | |
| <p>But if you’ve carefully followed, you probably noticed that the forward/backward passes for each micro-batch can actually be run in parallel. Forward/backward passes are independent from each other, with independent input samples being the only difference. Seems like it’s time to start extending our training to more than one GPU! </p> | |
| <p>Let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called <em><strong>data parallelism</strong> which is just a parallel version of gradient accumulation</em>.</p> | |
| <p>TODO: intro for this</p> | |
| <h2>torch.profiler</h2> | |
| <p><img alt="**Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done**" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%207.png" /></p> | |
| <p><strong>Overlapped backward pass (stream 7) and gradients accumulation (stream 28) means we start the optimizer step as soon as the backward pass is done</strong></p> | |
| <p><img alt="In this naive approach we see a long AllReduce operation (stream 28) happening to sync all gradients after the backward pass is over, after which comes the optimizer step. GPUs stay idle while the communication happens." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%208.png" /></p> | |
| <p>In this naive approach we see a long AllReduce operation (stream 28) happening to sync all gradients after the backward pass is over, after which comes the optimizer step. GPUs stay idle while the communication happens.</p> | |
| <h1>Data Parallelism</h1> | |
| <p>The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replicas “model instances”) and run forward and backward passes on different micro batches of data in parallel on each GPU, hence the name Data Parallelism.</p> | |
| <p>Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances are averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%209.png" /></p> | |
| <p>This involves our first “distributed communication” primitive: <a href="https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21"><strong>All-Reduce</strong></a>, which handles the synchronization and communication between GPU instances and nodes.</p> | |
| <blockquote> | |
| <p>If you are not familiar with distributed communications patterns like broadcast, gather or all-reduce we put together a small crash course in the Appendix [TODO Link]. | |
| </p> | |
| </blockquote> | |
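| <p>To illustrate what the averaging all-reduce achieves, the snippet below simulates it within a single Python process. No real communication library is involved, and <code>all_reduce_mean</code> is a hypothetical helper. After the operation every rank holds the same averaged gradient, so identical optimizer steps keep the model instances in sync:</p> | |

```python
# Minimal sketch: data parallelism simulated in one process. Each "rank" holds
# the gradient vector computed on its own micro-batch; all_reduce_mean stands in
# for a real all-reduce (e.g. NCCL) with an averaging reduction.

def all_reduce_mean(per_rank_grads):
    """Average gradient vectors element-wise and give every rank the same result."""
    world_size = len(per_rank_grads)
    summed = [sum(vals) for vals in zip(*per_rank_grads)]   # reduce (sum)
    mean = [v / world_size for v in summed]                 # average
    return [list(mean) for _ in range(world_size)]          # every rank gets a copy

# Hypothetical gradients of a 3-parameter model on 4 DP ranks
grads = [
    [0.1, -0.2, 0.3],
    [0.3,  0.0, 0.1],
    [-0.1, 0.4, 0.2],
    [0.5,  0.2, 0.0],
]
synced = all_reduce_mean(grads)
print(synced[0])
```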
| <p>TODO: bucket grads to avoid multiple comms | |
| TODO: show comms overlap</p> | |
| <p>TODO: any comms requires at least a contiguous buffer to do comms → TIP: make sure tensors that’ll be communicated are contiguous in memory to avoid redundant memory copies</p> | |
| <p>TODO: embed naive DP: <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60">https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60</a></p> | |
| <p>TODO: embed bucket DP: <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171">https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171</a></p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2010.png" /></p> | |
| <p>A naive DP implementation would just wait for the backward pass to finish so that we have all gradients, then trigger an all-reduce over all DP ranks to sync these gradients. But such a sequence of computation followed by communication is <strong>A BIG NO!</strong> Because we don’t want our GPUs to stay idle while communication is happening.</p> | |
| <p>Instead we should try to overlap communication and computation whenever possible so that they happen at the same time as much as possible.</p> | |
| <p>Let’s see three optimizations that are done in practice for this! </p> | |
| <h3><strong>First optimization:</strong> Overlap gradient synchronization with backward pass</h3> | |
| <p>The main drawback of the naive DDP approach we’ve just described is that after the backward pass (<em>computation</em>), we have to wait for gradient synchronization (<em>communication</em>) before updating the parameters. Could we overlap this communication with our computation? The answer is yes!</p> | |
| <p>As shown in the figure above, the gradients (red boxes) for a layer can be gathered and summed even before the gradients from earlier layers (red boxes to the left) have been computed. For example, as soon as the backward pass of the last layer is complete (last box on the right), those gradients can already be gathered and summed while the backward computations continue for earlier layers, moving toward the left.</p> | |
| <p>This can be achieved in PyTorch by attaching an <em>all-reduce hook function</em> to each parameter. An all-reduce operation is triggered as soon as the gradient for that parameter is ready, while gradients for other parameters are still being computed. This approach overlaps most of the all-reduce operations with gradient calculations, thereby improving efficiency.</p> | |
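| <pre><code class="language-python">def register_backward_hook(self, hook): | |
|     """ | |
|     Registers a backward hook for all parameters of the model that | |
|     require gradients. The hook fires once a parameter's gradient has | |
|     been accumulated into .grad, so it can launch that parameter's all-reduce. | |
|     """ | |
|     for p in self.module.parameters(): | |
|         if p.requires_grad: | |
|             # Available in recent PyTorch (2.1+); hook(param) is called after param.grad is updated | |
|             p.register_post_accumulate_grad_hook(hook) | |
| </code></pre> | |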
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2011.png" /></p> | |
| <p>Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with backward pass, significantly speeding up data parallelism. </p> | |
| <p>This is our first example of “<em>overlapping computation and communication</em>” which we will discuss several times in this blog post and is an essential technique to maximal scaling efficiency.</p> | |
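| <p>The scheduling idea can be sketched in a single process. This is not how PyTorch DDP works internally, and the function names are hypothetical. A one-worker thread pool plays the role of the communication stream: each layer's all-reduce is launched as soon as its gradient is "ready" in the reverse-order backward loop, and we only wait for all of them right before the optimizer step:</p> | |

```python
from concurrent.futures import ThreadPoolExecutor

# Minimal sketch of overlapping gradient sync with the backward pass. A real
# implementation would launch an async all-reduce (e.g. torch.distributed
# all_reduce with async_op=True) from a per-parameter hook; here a thread pool
# stands in for the communication stream and all_reduce_mean for the collective.

def all_reduce_mean(values):
    return sum(values) / len(values)

def backward_with_overlap(per_rank_grads_by_layer, comm):
    """per_rank_grads_by_layer[layer][rank] is that layer's gradient on that rank."""
    num_layers = len(per_rank_grads_by_layer)
    launched, futures = [], {}
    for layer in reversed(range(num_layers)):       # backward: last layer first
        grads = per_rank_grads_by_layer[layer]      # gradient just "computed"
        # Hook: launch this layer's all-reduce immediately, then keep going
        # with the backward computation of the previous layer.
        futures[layer] = comm.submit(all_reduce_mean, grads)
        launched.append(layer)
    # Before the optimizer step, wait for every outstanding communication.
    synced = [futures[layer].result() for layer in range(num_layers)]
    return synced, launched

grads = [[1.0, 3.0], [2.0, 4.0], [0.0, 2.0]]  # 3 layers, 2 DP ranks, 1 scalar grad each
with ThreadPoolExecutor(max_workers=1) as comm:
    synced, launched = backward_with_overlap(grads, comm)
print(synced, launched)
```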
| <h3><strong>Second optimization:</strong> Bucketing gradients</h3> | |
| <p>But we can even go further. For a given number of parameters to synchronize, GPU operations like collective communications are often more efficient when performing few calls on large tensors rather than many calls on smaller tensors. Therefore, instead of performing independent all-reduce for each gradient, we can group gradients into buckets and launch a single all-reduce for all the gradients within the same bucket. Think of it like packing items into boxes before shipping—it's more efficient to send a few big boxes than many small ones. By performing a single all-reduce operation for each bucket, we can significantly reduce communication overhead and speed up the communication operation.</p> | |
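| <p>A minimal sketch of how gradients could be packed into buckets. <code>assign_buckets</code> is a hypothetical helper, not PyTorch’s internal algorithm, and the gradient sizes are made up. Parameters are visited in the order their gradients become ready (reverse of the forward order), and a bucket is closed whenever adding the next gradient would exceed the cap:</p> | |

```python
# Minimal sketch of gradient bucketing (hypothetical helper). One all-reduce
# is launched per bucket instead of one per parameter.

def assign_buckets(grad_sizes_bytes, bucket_cap_bytes):
    """grad_sizes_bytes lists gradient sizes in forward order; returns lists of parameter indices."""
    buckets, current, current_bytes = [], [], 0
    for idx in reversed(range(len(grad_sizes_bytes))):   # backward order
        size = grad_sizes_bytes[idx]
        # Close the current bucket if adding this gradient would overflow it.
        # A single gradient larger than the cap still gets its own bucket.
        if current and current_bytes + size > bucket_cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets

MB = 1024 * 1024
sizes = [8 * MB, 16 * MB, 4 * MB, 4 * MB, 20 * MB, 12 * MB]  # hypothetical per-parameter grads
print(assign_buckets(sizes, bucket_cap_bytes=25 * MB))
print(assign_buckets(sizes, bucket_cap_bytes=100 * MB))
```

| <p>With a 25 MB cap these six gradients need four all-reduce calls, while with 100 MB a single call suffices. In PyTorch DDP this cap is exposed as the <code>bucket_cap_mb</code> argument of <code>DistributedDataParallel</code>.</p> | |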
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2012.png" /></p> | |
| <p>The selected bucket size will be a key factor in determining the efficiency of Data Parallelism, yet it is <strong>often overlooked</strong>. For instance, the default bucket size of <strong>25</strong> MB in PyTorch DDP is quickly becoming outdated given the rapid evolution of model sizes and hardware. Improvements in network bandwidth, GPU memory, and faster interconnects like NVLink or InfiniBand mean that larger buckets can be handled more efficiently, and increasing the bucket size, for example to <strong>100</strong> MB, can often result in better performance. The trade-off is that larger buckets may delay the start of communication (a bucket’s all-reduce only launches once the bucket is full) and require more memory, so the best value depends on the hardware and model being used.</p> | |
| <p>[TODO: benchmark all reduce with different size / bucket size results ?]</p> | |
| <h3><strong>Third optimization:</strong> Interplay with gradient accumulation</h3> | |
| <p>As we’ve seen before, gradient accumulation works by performing multiple forward and backward passes before updating the parameters with <code>optimizer.step()</code>. When combining gradient accumulation with data parallelism, we should be careful when we want to synchronize gradients.</p> | |
| <p>In a naive version, an all-reduce operation is automatically triggered after each backward pass during the accumulation, which is sub-optimal as a single reduce after the final step would have the same effect while reducing overhead.</p> | |
| <p>In PyTorch, this is typically solved with the <a href="https://github.com/pytorch/pytorch/blob/5ea67778619c31b13644914deef709199052ee55/torch/nn/parallel/distributed.py#L1408-L1435"><code>model.no_sync()</code></a> context manager, which disables gradient synchronization for the backward passes that don’t need reduction.</p> | |
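| <p>The behaviour can be sketched with a toy wrapper. <code>ToyDataParallel</code> is hypothetical and only mimics the documented idea behind DDP’s <code>no_sync()</code>: inside the context, backward passes accumulate gradients locally, and synchronization happens on the first backward outside it. With 4 accumulation steps this yields a single all-reduce instead of four:</p> | |

```python
from contextlib import contextmanager

# Minimal sketch of a hypothetical DP wrapper with a no_sync() context manager.
# Dividing by grad_acc (e.g. via the loss) is omitted to keep the numbers simple.

class ToyDataParallel:
    def __init__(self, world_size):
        self.world_size = world_size
        self.require_sync = True
        self.local_grad = [0.0] * world_size   # one accumulated grad per rank
        self.num_all_reduces = 0

    @contextmanager
    def no_sync(self):
        old = self.require_sync
        self.require_sync = False               # skip communication inside the context
        try:
            yield
        finally:
            self.require_sync = old

    def backward(self, per_rank_grads):
        for r, g in enumerate(per_rank_grads):
            self.local_grad[r] += g             # accumulate locally on each rank
        if self.require_sync:
            mean = sum(self.local_grad) / self.world_size   # simulated all-reduce
            self.local_grad = [mean] * self.world_size
            self.num_all_reduces += 1

grad_acc = 4
model = ToyDataParallel(world_size=2)
micro_batch_grads = [[1.0, 3.0], [2.0, 2.0], [0.0, 4.0], [1.0, 1.0]]
for step, grads in enumerate(micro_batch_grads):
    if step < grad_acc - 1:
        with model.no_sync():      # intermediate micro-batches: no communication
            model.backward(grads)
    else:
        model.backward(grads)      # last micro-batch: synchronize once
print(model.local_grad, model.num_all_reduces)
```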
| <h2>Revisit global batch size</h2> | |
| <p>Let’s update our batch size equation with our newly learned Data Parallelism and Gradient Accumulation parameters:</p> | |
| <p>$$ | |
| gbs = mbs \times grad\_acc \times dp | |
| $$</p> | |
| <p>Where $grad\_acc$ is the number of gradient accumulation steps and $dp$ is the number of parallel instances used for data parallelism.</p> | |
| <p>Given a target global batch size, we can thus trade gradient accumulation steps for data-parallel processes to speed up training. In practice, people tend to maximize the number of data-parallel ranks (DP) over gradient accumulation as much as possible, since DP is inherently parallel, unlike the sequential nature of gradient accumulation. Gradient accumulation is then added on top of data parallelism to reach the target global batch size when we run out of GPUs before data parallelism alone gets there.</p> | |
| <blockquote> | |
| <p>A good resource for further reading on Data Parallelism is <a href="https://siboehm.com/articles/22/data-parallel-training">https://siboehm.com/articles/22/data-parallel-training</a>. | |
| </p> | |
| </blockquote> | |
| <p>Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 3 more dimensions).</p> | |
| <h2>Our journey up to now</h2> | |
| <p>Let’s quickly summarize what we’ve seen up to now and how to set up our first 1D parallel training, with a draft recipe for an optimal data-parallel setup:</p> | |
| <ol> | |
| <li>We should first determine the best (global) batch size in tokens (<code>GBST</code>) either by consulting literature or running experiments measuring model convergence.</li> | |
| <li>We then select a sequence length for training, again by either consulting literature or running experiments. Generally, 2-8k tokens work reliably well for the evaluations we have today (we won’t dive into training recipes here, but teams usually increase the sequence length at the end of training, adding some longer-context data samples to the mix to reach today’s longer context sizes).</li> | |
| <li>We now know the batch size (gbs). We can find the maximum local batch size (mbs) on a single GPU by increasing the local batch size until we run out of memory.</li> | |
| <li>Finally, we determine the number of available GPUs for our target DP. The ratio GBS / (MBS × DP), with both batch sizes counted in samples, gives us the remaining number of gradient accumulation steps needed for the desired GBS.</li> | |
| </ol> | |
| <blockquote> | |
| <p>Note: For instance DeepSeek and Llama models are trained with a 4k tokens sequence length during the main pretraining phase. | |
| </p> | |
| <p>Note: The reason 2-8k work well for pretraining is that documents that are longer are very rare on the web. See this <a href="https://www.harmdevries.com/post/context-length/">Harm’s blogpost</a> for a detailed analysis. | |
| </p> | |
| </blockquote> | |
| <p>If the gradient accumulation ratio is lower than one, i.e. we have too many GPUs a.k.a GPU-rich 🤑 (!), we can either choose to not use all our GPUs, explore a larger global batch size or test if a lower MBS will speed up training. In the latter case we’ll end up prioritizing throughput over individual GPU compute efficiency, using a smaller MBS than possible in order to speed up training.</p> | |
| <p>Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!</p> | |
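| <p>The arithmetic of this example follows directly from $gbs = mbs \times grad\_acc \times dp$. The helper below is a hypothetical illustration, taking 4M tokens as $4 \times 2^{20}$ and 4k as 4096, in line with the powers of two above:</p> | |

```python
# Minimal sketch: number of gradient accumulation steps from
# gbs = mbs * grad_acc * dp, with gbs and mbs counted in samples.

def grad_acc_steps(gbs_tokens, seq_len, mbs, dp):
    gbs_samples, rem = divmod(gbs_tokens, seq_len)
    assert rem == 0, "global batch size in tokens must be a multiple of seq_len"
    grad_acc, rem = divmod(gbs_samples, mbs * dp)
    # Fails if mbs * dp does not divide gbs, including the "GPU-rich" case
    # where mbs * dp exceeds gbs (grad_acc would be below one).
    assert rem == 0, "gbs must be divisible by mbs * dp"
    return gbs_samples, grad_acc

print(grad_acc_steps(4 * 2**20, 4096, mbs=2, dp=128))   # 128 GPUs
print(grad_acc_steps(4 * 2**20, 4096, mbs=2, dp=512))   # 512 GPUs
```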
| <blockquote> | |
| <p>Bear in mind that at the 512-GPU scale, depending on the network used, the communication operations will start to be bound by ring latency, which means we can no longer fully overlap the DP communications. This will decrease our compute efficiency and hurt our throughput. In this case we should start exploring other dimensions to parallelize on. | |
| </p> | |
| </blockquote> | |
| <p>TODO: We’re gaining overall throughput but losing efficiency as we scale DP too much</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2013.png" /></p> | |
| <p><strong>We’ve explored data parallelism, our first (simple) strategy to scale training across more GPUs. It works like gradient accumulation but parallelizes the forward and backward passes on micro batches, thus increasing throughput!</strong></p> | |
| <p>The keen reader has probably already noticed, however, that this assumes that we can fit at least one input sample forward pass (<code>mbs=1</code>) into our GPU memory (with activation recomputation if needed).</p> | |
| <p>This is not always the case! As we’ve seen earlier, larger models often don’t fit into a single GPU, even with activation recomputation enabled.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2014.png" /></p> | |
| <p>Do we have other options for these larger models? Thankfully, we do. They involve either moving some of these tensors to the CPU or splitting the weights/gradients/optimizer-states tensors across GPU devices!</p> | |
| <p>There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharding (DeepSpeed ZeRO or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined! The sharding paradigm is closely related to DP, so we’ll have a look at it first by investigating the ZeRO method!</p> | |
| <h2>ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer)</h2> | |
| <p>In this section we will introduce DeepSpeed ZeRO (<strong>Ze</strong>ro <strong>R</strong>edundancy <strong>O</strong>ptimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.</p> | |
| <blockquote> | |
| <p>Note: We’ll focus on ZeRO-1 to ZeRO-3 in this blog as it should give a broad view on how it helps reduce memory while showing the tradeoffs to take into account. You can find more ZeRO flavors in the <a href="https://www.deepspeed.ai/tutorials/zero/">Deepspeed docs</a>. | |
| </p> | |
| </blockquote> | |
| <p>While Data Parallelism is very efficient in scaling training, the naive replication of optimizer states, gradients, and parameters across each DP rank introduces a significant memory redundancy. ZeRO eliminates memory redundancy by partitioning the optimizer states, gradients, and parameters across the data parallel dimension, while still allowing computation with the full set of parameters. This sometimes requires more communications between DP ranks which may or may not be fully overlapped as we’ll see next!</p> | |
| <p>This approach is organized into three possible optimization stages of ZeRO:</p> | |
| <ul> | |
| <li>ZeRO-1: optimizer state partitioning</li> | |
| <li>ZeRO-2: optimizer state + gradient partitioning</li> | |
| <li>ZeRO-3 (also called FSDP for “Fully-Sharded Data Parallelism”): optimizer state + gradient + parameter partitioning</li> | |
| </ul> | |
| <blockquote> | |
| <p>Note: You might be missing the activations among the things we can shard. Since each DP replica of the model receives a different microbatch the activations on each DP rank also differ so they are not duplicated and thus can’t be sharded! | |
| </p> | |
| </blockquote> | |
| <p>Let’s have a closer look how much we can save with the partitioning of each ZeRO stage!</p> | |
| <h3>Memory usage revisited</h3> | |
| <p>Let’s first recap the memory usage of optimizer states, gradients, and parameters during a standard training. Let’s define the number of our model's parameters as $\Psi$ (previously N but here we use the original ZeRO notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:</p> | |
| <ul> | |
| <li>Model’s parameters (half precision i.e. bf16/fp16): $2\Psi$</li> | |
| <li>Model’s gradients (half precision i.e. bf16/fp16): $2\Psi$</li> | |
| <li>Model’s parameters in fp32 and optimizer states: $4\Psi + (4\Psi + 4\Psi)$</li> | |
| <li>Model’s gradients in fp32: $4\Psi$ (optional, only accounted if we want to accumulate grads in fp32)</li> | |
| </ul> | |
| <p>If we don’t accumulate gradients in fp32 this gives us a total memory consumption of $2\Psi + 2\Psi + 12\Psi$, and if we accumulate it would be $2\Psi + 6\Psi + 12\Psi$. Let’s focus for now on the case without fp32 gradient accumulation for simplicity but you can just add the additional bytes to the gradient term which are affected by ZeRO-2 and 3. </p> | |
| <p>The idea of ZeRO is to shard these objects across the DP ranks, each node only storing a slice of the items which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree $N_d$:</p> | |
| <p><img alt="Memory consumption of DP and three stages of Zero-DP. $\Psi$ denotes number of parameters, $k$ denotes the memory multiplier of optimizer states ($k=12$ for Adam), and $N_d$ denotes DP degree." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2015.png" /></p> | |
| <p>Memory consumption of DP and three stages of Zero-DP. $\Psi$ denotes number of parameters, $k$ denotes the memory multiplier of optimizer states ($k=12$ for Adam), and $N_d$ denotes DP degree.</p> | |
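| <p>The per-GPU model-state memory for each stage can be computed directly from the formulas in this section (vanilla DP, then $2\Psi + 2\Psi + \frac{k\Psi}{N_d}$, $2\Psi + \frac{2\Psi+k\Psi}{N_d}$ and $\frac{2\Psi+2\Psi+k\Psi}{N_d}$). The function below is a small illustrative calculator. The 7.5B-parameter model on 64 DP ranks is just an example input, and activations are deliberately not included:</p> | |

```python
# Per-GPU memory (bytes) for model states: parameters, gradients and optimizer
# states. psi = number of parameters, k = bytes/param of optimizer states
# (k = 12 for mixed-precision Adam), n_d = DP degree. Activations excluded.

def model_state_bytes(psi, n_d, k=12):
    return {
        "dp":     2 * psi + 2 * psi + k * psi,
        "zero-1": 2 * psi + 2 * psi + k * psi / n_d,
        "zero-2": 2 * psi + (2 * psi + k * psi) / n_d,
        "zero-3": (2 * psi + 2 * psi + k * psi) / n_d,
    }

# Illustrative values: a 7.5B-parameter model on 64 DP ranks
for stage, b in model_state_bytes(7.5e9, 64).items():
    print(f"{stage}: {b / 1e9:.1f} GB")
```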
| <p>Let’s explain this graph and its values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.</p> | |
| <h3>ZeRO-1: Partitioning Optimizer States</h3> | |
| <p>In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?</p> | |
| <p>In ZeRO-1, the optimizer states are partitioned into $N_d$ equal parts where $N_d$ is the DP degree. This means that each model replica distributed on each DP rank only keeps track of $\frac{1}{N_d}$ of the optimizer states. During the optimization step only $\frac{1}{N_d}$ of the float32 weights are updated, which we cast to get the corresponding $\frac{1}{N_d}$ portion of the bfloat16 parameters.</p> | |
| <p>However for the forward pass, we need all our bfloat16 parameters, we thus need to add an additional <a href="https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21"><strong><em>all-gather</em></strong></a> (the second type of collective communication primitive we encounter!) after the optimizer step so that each model replica has the full set of updated weights. </p> | |
| <p>This explains the memory formula of $2\Psi + 2\Psi + \frac{k\Psi}{N_d}$ that we saw on the above graph! Here’s a summary of the sequence of operations for a single training step:</p> | |
| <ul> | |
| <li>Forward pass with all bf16 parameters (but different microbatches across DP ranks)</li> | |
| <li>Backward pass with all gradients (but different microbatches across DP ranks)</li> | |
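| <li>Perform a reduce-scatter on the gradients (a reduce-scatter moves roughly half the data of an all-reduce, which can be decomposed into a reduce-scatter followed by an all-gather, so it is about 2 times faster when bandwidth-bound!)</li> | |
| <li>Each replica performs an optimizer step on its own slice only (it holds just 1/$N_d$ of the optimizer states), updating 1/$N_d$ of the fp32 parameters and then the corresponding 1/$N_d$ of the bf16 parameters</li> | |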
| <li>[New operation in ZeRO, not in vanilla DP] Perform an all-gather of bf16 parameters to send missing slices back to each replica</li> | |
| </ul> | |
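| <p>The sequence above can be sketched in a single process to check that ZeRO-1 produces the same update as vanilla DP. This is a simplified illustration: plain SGD stands in for Adam so the optimizer state stays trivial, the bf16/fp32 casting is omitted, and <code>reduce_scatter_mean</code>, <code>all_gather</code> and <code>zero1_step</code> are hypothetical helpers:</p> | |

```python
# Single-process sketch of one ZeRO-1 step (SGD instead of Adam, no casting).

def reduce_scatter_mean(per_rank_grads, n_d):
    """Average gradients across ranks; rank r only receives shard r of the result."""
    shard = len(per_rank_grads[0]) // n_d
    mean = [sum(vals) / n_d for vals in zip(*per_rank_grads)]
    return [mean[r * shard:(r + 1) * shard] for r in range(n_d)]

def all_gather(shards):
    """Concatenate every rank's shard so that all ranks hold the full vector."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]

def zero1_step(params, per_rank_grads, lr):
    n_d = len(per_rank_grads)
    assert len(params) % n_d == 0, "sketch assumes params split evenly across ranks"
    shard = len(params) // n_d
    grad_shards = reduce_scatter_mean(per_rank_grads, n_d)
    # Each rank updates only the parameter slice whose optimizer state it owns.
    new_shards = [
        [p - lr * g for p, g in zip(params[r * shard:(r + 1) * shard], grad_shards[r])]
        for r in range(n_d)
    ]
    return all_gather(new_shards)   # every replica gets the full updated params

params = [1.0, 2.0, 3.0, 4.0]
grads = [[0.5, 1.0, 1.5, 2.0], [1.5, 1.0, 0.5, 0.0]]   # 2 DP ranks
replicas = zero1_step(params, grads, lr=0.1)

# Vanilla DP reference: all-reduce the full gradient, every rank updates everything.
mean = [sum(v) / 2 for v in zip(*grads)]
reference = [p - 0.1 * g for p, g in zip(params, mean)]
print(replicas[0], reference)
```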
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2016.png" /></p> | |
| <p>If you've been following along, you'll recall from vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of bf16 parameters. There are two main strategies for this:</p> | |
| <ol> | |
| <li>Right after optimizer step: We can initiate the all-gather immediately after the optimizer updates the parameters. This allows the communication to potentially overlap with other post-optimization operations.</li> | |
| <li>Right before forward: We can delay the all-gather until just before we need the parameters for the next forward pass. This approach gives us more flexibility to overlap with any computation happening between training steps.</li> | |
| </ol> | |
| <p>But unfortunately these techniques are not as straightforward to implement as they seem and require sophisticated use of hooks/bucketing. In practice we can just use a ZeRO-3/FSDP implementation where the FSDPUnit is the entire model; more details about this later.</p> | |
| <h3>ZeRO-2: Adding <strong>Gradient Partitioning</strong></h3> | |
| <p>In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates $\frac{1}{N_d}$ of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place since only a subset is needed for the optimization step. </p> | |
| <p>→ During the backward pass, instead of performing an all-reduce over the gradients, we can therefore perform a <strong><em>reduce-scatter [TODO: add link]</em></strong> operation! <em>(yay, a third communication primitive!)</em> This way each rank only keeps in memory the $\frac{1}{N_d}$ of the gradients it needs, saving more memory compared to ZeRO-1.</p> | |
| <blockquote> | |
| <p>In case of FP32 gradient accumulation, we only need to keep $\frac{1}{N_d}$ fp32_grads where we accumulate the bf16 grads coming from the reduce-scatter. And in the optimizer step we use the $\frac{1}{N_d}$ fp32_grads. | |
| </p> | |
| </blockquote> | |
| <p><img alt="zero-1.gif" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/zero-1.gif" /></p> | |
| <p>It’s easy to see now that sharding the gradients leads to $2\Psi + \frac{2\Psi+k\Psi}{N_d}$, and as $N_d$ increases we can save up to 8x memory over the baseline (the baseline needs $16\Psi$ with $k=12$, while ZeRO-2 approaches $2\Psi$). In terms of communication the same process applies as for ZeRO-1, with the only difference that we communicate and release on the fly. In total, ZeRO-2 is thus also equivalent to vanilla DP training w.r.t. communication.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2017.png" /></p> | |
| <blockquote> | |
| <p>Note: You might notice that there is no real overhead of using ZeRO-2 over ZeRO-1, and indeed ZeRO-2 is usually the best option. The reason some distributed training frameworks don’t support it is that gradient sharding may interfere with, and complicate, other parallelism strategies we discuss later. | |
| </p> | |
| </blockquote> | |
| <p>Now that we’ve sharded gradients as well, are we done? Or can we push this even further? Well, sort of. We would like to reduce the memory of the parameters as well, and we’ve seen that we don’t need to wait for the entire all-gather to start the forward pass: we can already start it once we get the first layer. Here comes ZeRO-3!</p> | |
| <h3>ZeRO-3: Adding Parameter <strong>Partitioning</strong></h3> | |
| <p>For Stage 3 we extend the above approach of sharding tensors over DP replicas up to sharding the model’s parameters.</p> | |
| <blockquote> | |
| <p>Note: This stage is also called <strong>FSDP</strong> (Fully Sharded Data Parallelism) in PyTorch’s native implementation. We’ll just refer to ZeRO-3 in this blog post, but you can think of FSDP wherever you see it. | |
| </p> | |
| </blockquote> | |
| <p>So how do we do a forward or backward pass in practice if all parts of the model are distributed? Quite simply we gather them on-demand when we need them. In the forward pass this looks as follows:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2018.png" /></p> | |
| <p>So as we perform the forward pass and sequentially go through the layers we retrieve the necessary parameters on demand and immediately flush them from memory when we don’t need them anymore. The backward pass works the same way just inverted in flow and we produce the gradient shards: </p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2019.png" /></p> | |
| <p>During the forward pass we do all-gather operations for the parameters when we need them, so a $\Psi$ communication tax. Since we discard the parameters immediately after we needed them in the forward pass we need one more all-gather during the backward pass as well incurring another $\Psi$ in communication tax. Finally we need the same <strong><em>reduce-scatter</em></strong> as in ZeRO-2 for the gradients which costs also $\Psi$ in communication and we arrive at a total communication cost of $3\Psi$, compared to $2\Psi$ for Zero-2. </p> | |
| <p>The other issue is that we need to do these all-gathers continuously throughout the forward and backward step, which amounts to <code>2 * num_layers - 1</code> additional all-gathers in a training step compared to Zero-2 as we can see in the following figure:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2020.png" /></p> | |
| <p>$$ | |
| \frac{t_{comm}}{t_{compute}} = \frac{(DP-1) \cdot peak_{flops}}{2 \cdot seq \cdot mbs \cdot peak_{bw}} | |
| $$</p> | |
| <p>Overall it may sound like we significantly increase communication overhead, but thanks to <strong>prefetching</strong> we can start all-gathering weights for Layer n+1 while we do the current forward for Layer n which usually overlaps communication and computation as long as we don’t scale DP too much (as a rule of thumb: DP<512).</p> | |
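| <p>Prefetching can be sketched in a single process as well. Real ZeRO-3/FSDP implementations shard flattened parameters and use dedicated CUDA streams. Here each hypothetical layer is an element-wise weight vector split into per-rank chunks, and a one-worker thread pool stands in for the communication stream. While layer $n$ computes, the all-gather for layer $n+1$ is already in flight, and the gathered weights are dropped right after use:</p> | |

```python
from concurrent.futures import ThreadPoolExecutor

# Sketch of a ZeRO-3-style forward pass with prefetching (hypothetical helpers).

def all_gather(shards):
    """Rebuild a full weight vector from its per-rank shards."""
    return [x for s in shards for x in s]

def layer_forward(weight, x):
    return [w * xi for w, xi in zip(weight, x)]   # element-wise (diagonal) layer

def zero3_forward(sharded_layers, x, comm):
    """sharded_layers[i] is the list of per-rank shards of layer i's weight."""
    num_layers = len(sharded_layers)
    pending = comm.submit(all_gather, sharded_layers[0])     # gather the first layer
    for i in range(num_layers):
        weight = pending.result()                             # wait for layer i's params
        if i + 1 < num_layers:
            pending = comm.submit(all_gather, sharded_layers[i + 1])  # prefetch next layer
        x = layer_forward(weight, x)
        del weight      # drop the gathered copy; only the shards stay resident
    return x

layers_full = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 2.0, 1.0], [2.0, 1.0, 1.0, 0.5]]
n_d = 2
shard = len(layers_full[0]) // n_d
sharded = [[w[r * shard:(r + 1) * shard] for r in range(n_d)] for w in layers_full]

x0 = [1.0, 1.0, 1.0, 1.0]
with ThreadPoolExecutor(max_workers=1) as comm:
    out = zero3_forward(sharded, x0, comm)

reference = x0
for w in layers_full:                 # unsharded forward for comparison
    reference = layer_forward(w, reference)
print(out, reference)
```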
| <p>In terms of memory we can see that our equation has now reached its final form of $\frac{2\Psi +2\Psi+k\Psi}{N_d}$, which means we can drive memory usage down indefinitely by increasing the DP degree, at least for the model-related parameters. Notice how it doesn’t specifically help with the intermediate activations that we discussed in the previous chapter. ZeRO is orthogonal to the activation checkpointing and gradient accumulation techniques we discussed in other chapters.</p> | |
| <blockquote> | |
| <p>If you want to read more about FSDP1 vs FSDP2 and some of the implementation complexities around them, you should take some time to go over this nice blog: <a href="https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/">https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-012/</a> | |
| </p> | |
| </blockquote> | |
| <p><strong>Let’s summarize our journey into DP and ZeRO so far: we have seen that we can increase the throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients and optimizer states across DP, while incurring a small communication cost.</strong></p> | |
| <p>However, there is a limit here, DP only works if a layer of the model fits in a single GPU and ZeRO can only reduce the parameters, gradients, and optimizer states, but not the activation memory! Recall from the activation memory discussion that it scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train with e.g. short sequence length. </p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2021.png" /></p> | |
| <p>As models grow bigger and even a single layer may not fit in a GPU, we need more tools in our distributed training toolbox to scale further.</p> | |
| <p>Let’s explore a totally different approach called Tensor Parallelism which can, by the way, be combined with ZeRO.</p> | |
| <h1>Tensor Parallelism</h1> | |
| <p>So we have sharded the model’s parameters, gradients and optimizer states with ZeRO, but we hit a limit once activation memory overtakes our memory budget. Welcome Tensor Parallelism (TP), a method which shards weights, gradients, and optimizer states as well as activations, and without the need to gather them all prior to the computation. Seems like a dream :) Let’s first have a look at how Tensor Parallelism works with simple matrix multiplications.</p> | |
| <p>Tensor Parallelism leverages the mathematical properties of matrix multiplication <code>A × B</code>. To understand how it works, let's examine two fundamental equations that make this parallelization possible:</p> | |
| <p>$$ | |
| \begin{aligned} | |
| \text{1.} \quad A\cdot B = A \cdot \begin{bmatrix} | |
| B_1 & B_2 & \cdots | |
| \end{bmatrix} &= | |
| \begin{bmatrix} | |
| AB_1 & AB_2 & \cdots | |
| \end{bmatrix} | |
| \end{aligned} | |
| $$</p> | |
| <p>$$ | |
| \begin{aligned}\text{2.} \quad A\cdot B =\begin{bmatrix} A_1 & A_2 & \cdots \end{bmatrix} \begin{bmatrix} B_1 \\ B_2 \\ \vdots \end{bmatrix} &= \sum_{i=1}^n A_i B_i\end{aligned} | |
| $$</p> | |
| <p>This means that we can compute the matrix product either 1) by multiplying $A$ with each column block $B_i$ of $B$ separately and concatenating the results, or 2) by splitting $A$ into column blocks $A_i$ and $B$ into the matching row blocks $B_i$, multiplying each pair individually, and summing the results. In a neural network, the matrix multiplication is more often represented in the following format: <code>X × W</code>, where:</p> | |
| <ul> | |
| <li>X represents the input or activation values</li> | |
| <li>W represents the weight of the <code>nn.Linear</code></li> | |
| </ul> | |
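<p>Both identities are easy to check numerically. The sketch below uses numpy (an assumption; the text itself shows no code) and lays <code>W</code> out as <code>(hidden_in, hidden_out)</code> so the product reads <code>X × W</code>; note that PyTorch’s <code>nn.Linear</code> actually stores the transpose of this matrix.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 6))  # activations: (tokens, hidden_in)
W = rng.standard_normal((6, 8))  # weight laid out as (hidden_in, hidden_out) so that Y = X @ W

# Equation 1: split W into column blocks; each block yields a slice of the output columns.
W_cols = np.split(W, 2, axis=1)  # two (6, 4) blocks
out_cols = np.concatenate([X @ w for w in W_cols], axis=1)

# Equation 2: split X into column blocks and W into the matching row blocks;
# each pair yields a full-size partial product, and the partial products are summed.
X_cols = np.split(X, 2, axis=1)  # two (4, 3) blocks
W_rows = np.split(W, 2, axis=0)  # two (3, 8) blocks
out_rows = sum(x @ w for x, w in zip(X_cols, W_rows))

print(np.allclose(out_cols, X @ W), np.allclose(out_rows, X @ W))
```

<p>The first form concatenates, the second one sums: this difference is what later decides which collective each sharding scheme needs.</p>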
| <p>In practice a small example of the operation looks like this:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2022.png" /></p> | |
<p>Let’s see how we can parallelise this operation! In tensor parallelism, tensors will be split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split either along their columns or along their rows, leading to column and row parallelism. One thing we’ll see in the following is that choosing row or column sharding will require different communication primitives.</p>
<p>Our first option is to use column-wise sharding (also called <strong><em>column-linear</em></strong>): We'll copy the complete input matrices to each worker, requiring an operation called <a href="https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21"><strong><em>broadcast</em></strong></a>, and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an <a href="https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21"><strong><em>all-gather</em></strong></a> operation.</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2023.png" /></p> | |
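<p>To make the data movement concrete, here is a minimal single-process sketch of column-linear sharding, assuming numpy and simulating the <code>tp</code> workers as list entries (the name <code>column_linear</code> is hypothetical, not a library API). The broadcast becomes a copy of <code>X</code> per worker and the all-gather becomes a concatenation along the output dimension.</p>

```python
import numpy as np

def column_linear(X, W, tp):
    """Single-process simulation of a column-linear layer sharded over `tp` workers."""
    # Broadcast: every worker receives a full copy of the input X.
    inputs = [X.copy() for _ in range(tp)]
    # Worker i keeps only the i-th block of columns of W.
    w_shards = np.split(W, tp, axis=1)
    # Each worker produces a (tokens, hidden_out / tp) slice of the output.
    partial = [x @ w for x, w in zip(inputs, w_shards)]
    # All-gather: concatenate the slices along the output dimension.
    return np.concatenate(partial, axis=-1)

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 6))
W = rng.standard_normal((6, 8))
Y = column_linear(X, W, tp=4)
```

<p>Note that <code>np.split</code> requires <code>hidden_out</code> to be divisible by <code>tp</code>, which mirrors the usual requirement that weight shards have equal sizes.</p>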
| <p>The second option is called row-wise sharding (also called <strong><em>row-linear</em></strong>): As the attentive reader might guess, row-linear means that we split the weight matrix into chunks of rows. However, this also requires us to split the inputs, which needs a <strong><em>scatter</em></strong> operation rather than a broadcast as used in column-linear sharding. The results on each worker are already in the right shape but need to be summed for the final result, thus requiring an all-reduce operation in this scenario.</p> | |
<p>We see here our fourth distributed primitive: <strong><em><a href="https://www.notion.so/The-Ultra-Scale-Playbook-Training-LLMs-on-GPU-Clusters-af1b4137215e4e4eb1971e7dfa3185a9?pvs=21">scatter</a></em></strong>!</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2024.png" /></p> | |
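<p>The row-linear case can be sketched the same way, again assuming numpy and a hypothetical <code>row_linear</code> helper that simulates the workers in one process. The scatter splits <code>X</code> along its last dimension, every worker produces a full-shape but partial result, and the all-reduce is a plain sum that every worker receives.</p>

```python
import numpy as np

def row_linear(X, W, tp):
    """Single-process simulation of a row-linear layer sharded over `tp` workers."""
    # Scatter: worker i receives the i-th block of columns of X (its slice of hidden_in).
    x_shards = np.split(X, tp, axis=-1)
    # Worker i keeps the matching i-th block of rows of W.
    w_shards = np.split(W, tp, axis=0)
    # Each partial product already has the final shape (tokens, hidden_out)...
    partial = [x @ w for x, w in zip(x_shards, w_shards)]
    # ...but only their sum is correct: all-reduce, so every worker ends with the total.
    total = sum(partial)
    return [total.copy() for _ in range(tp)]

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 6))
W = rng.standard_normal((6, 8))
outs = row_linear(X, W, tp=3)
```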
| <h2>Tensor Parallelism in a Transformer Block</h2> | |
<p>To come up with a strategy to follow, let’s move from a toy example to a real model building block. A Transformer model is made of two main building blocks: Feedforward layers (MLP) and Multi-Head Attention (MHA). We can apply tensor parallelism to both.</p>
<p>The Feedforward part can be parallelized by having a “Column linear” followed by a “Row Linear”, which amounts to a broadcast to copy the input and an all-reduce in the forward pass. Note that the broadcast isn’t needed in actual training, where we can make sure inputs are already synced across TP ranks.</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2025.png" /></p> | |
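<p>Why does this pairing need only one all-reduce? Because GeLU is elementwise, each rank can apply it to its own column slice without talking to the others, and the row-linear layer then consumes exactly that slice. A minimal numpy sketch, assuming the tanh approximation of GeLU and hypothetical names <code>A</code>, <code>B</code> for the two weight matrices, checks that the sharded MLP matches the unsharded one:</p>

```python
import numpy as np

def gelu(x):
    # tanh approximation of GeLU; it is elementwise, so each rank can apply it to its own shard
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def mlp_reference(X, A, B):
    return gelu(X @ A) @ B

def mlp_tensor_parallel(X, A, B, tp):
    a_shards = np.split(A, tp, axis=1)  # column-linear: rank i owns output columns block i of A
    b_shards = np.split(B, tp, axis=0)  # row-linear: rank i owns the matching rows block i of B
    # X is assumed to be identical on every rank already, so no broadcast is simulated.
    partial = [gelu(X @ a) @ b for a, b in zip(a_shards, b_shards)]
    # The only communication in the forward pass: an all-reduce (sum) of the partial outputs.
    return sum(partial)

rng = np.random.default_rng(0)
h, ffn = 8, 32                          # illustrative sizes: hidden size and MLP width
X = rng.standard_normal((5, h))
A = rng.standard_normal((h, ffn))
B = rng.standard_normal((ffn, h))
out = mlp_tensor_parallel(X, A, B, tp=4)
```

<p>Swapping the order (row-linear first) would force a communication step before the GeLU, since a nonlinearity cannot be applied to partial sums.</p>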
<p>Now that we’ve found the most efficient scheme for the Feedforward part of the transformer, let’s take a look at the multi-head attention block (MHA).</p>
<p>We can generally follow a similar approach where Q, K, and V matrices are split in a column-parallel fashion, and the output projection is split along the row dimension. With multi-head attention, the column-parallel approach has a very natural interpretation: each worker computes the attention for an individual head or a subset of heads. The same approach works as well for <a href="https://arxiv.org/abs/1911.02150"><strong><em>multi-query</em></strong> (MQA)</a> or <a href="https://arxiv.org/abs/2305.13245"><strong><em>grouped query attention</em></strong> (GQA)</a> where keys and values are shared between queries.</p>
<p>It’s also worth noting that the tensor parallelism degree should not exceed the number of Q/K/V heads, because we need intact heads per TP rank. And in case we’re using GQA, the TP degree should not exceed the number of K/V heads, otherwise it requires additional comms to keep them in sync. For instance, LLaMA-3 8B has 8 Key/Value heads, so the tensor parallelism degree should be less than or equal to 8; otherwise, if TP=16 for example, we need to duplicate each K/V head and make sure the copies stay in sync.</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2026.png" /></p> | |
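<p>The head-count constraint can be written as a small check. The helper <code>heads_per_rank</code> below is hypothetical and only sketches the bookkeeping; for the LLaMA-3 8B example we assume 32 query heads next to the 8 K/V heads mentioned above.</p>

```python
def heads_per_rank(n_q_heads: int, n_kv_heads: int, tp: int) -> dict:
    """Sketch of how attention heads map onto TP ranks (hypothetical helper)."""
    if n_q_heads % tp != 0:
        raise ValueError("each TP rank needs a whole number of query heads")
    q_per_rank = n_q_heads // tp
    if tp <= n_kv_heads:
        if n_kv_heads % tp != 0:
            raise ValueError("each TP rank needs a whole number of K/V heads")
        return {"q_per_rank": q_per_rank, "kv_per_rank": n_kv_heads // tp, "kv_replicas": 1}
    # More ranks than K/V heads: each K/V head must be duplicated on several ranks,
    # and those copies have to be kept in sync (the extra cost described above).
    if tp % n_kv_heads != 0:
        raise ValueError("TP degree must be a multiple of the K/V head count")
    return {"q_per_rank": q_per_rank, "kv_per_rank": 1, "kv_replicas": tp // n_kv_heads}

print(heads_per_rank(32, 8, tp=8))   # every rank owns exactly one K/V head
print(heads_per_rank(32, 8, tp=16))  # each K/V head now lives on two ranks
```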
<p>Finally, note that there is a tradeoff in terms of communication, as we’ve added several distributed communication primitives directly in the computation path of our model. Unlike ZeRO, where we could prefetch, it can be harder to make these communications fully overlap with computation.</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2027.png" /></p> | |
<p>Looking at the timeline of operations in tensor-parallel MLP (the same applies to Attention), we can better understand the tradeoffs involved. In the forward pass of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This <em>exposed communication</em> overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied. This illustrates one of the key challenges with tensor parallelism: while it helps distribute large matrix multiplications, it does not reduce the activation memory of operations like LayerNorm, which still need the full activations on every rank. Additionally, it introduces communication requirements that heavily depend on the network infrastructure. The inability to hide this particular AllReduce behind computation means it directly adds to the critical path of the forward pass.</p>
| <p><img alt="Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2028.png" /></p> | |
| <p>Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.</p> | |
| <p>In practice, the communication overhead of tensor parallelism becomes particularly noticeable as we scale beyond 8 GPUs. While tensor parallelism within a single node can leverage fast NVLink interconnects, going across nodes requires slower network connections. As shown in the throughput plot above, we observe significant drops when moving from TP=8 to TP=16, and an even steeper decline from TP=16 to TP=32. This illustrates how communication costs can dominate at higher degrees of parallelism.</p> | |
| <p>However, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients and optimizer states across GPUs. Let's examine this effect on a 70B parameter model:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2029.png" /></p> | |
| <p>As we can see, increasing tensor parallelism reduces the memory needed for model parameters, gradients and optimizer states on each GPU. However, the activation memory remains constant across TP configurations. This is because operations like layer normalization and dropout require gathering the full activations on each GPU, effectively negating the memory savings we gained by sharding activations in the attention and feedforward layers.</p> | |
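<p>The shrinking part of the plot follows simple arithmetic. As a back-of-the-envelope sketch, assume mixed-precision Adam at 16 bytes per parameter (bf16 weights and gradients, plus fp32 master weights and two fp32 Adam moments) and perfectly even sharding; the helper name is hypothetical and activations are deliberately left out, since they are the part TP does not shrink here.</p>

```python
def tp_state_memory_gb(n_params: float, tp: int, bytes_per_param: int = 16) -> float:
    """Idealized per-GPU memory (GB) for params + grads + optimizer states under TP.

    bytes_per_param=16 assumes bf16 weights (2) + bf16 grads (2)
    + fp32 master weights (4) + fp32 Adam first and second moments (4 + 4).
    """
    return n_params * bytes_per_param / tp / 1e9

for tp in (1, 2, 4, 8, 16):
    print(f"TP={tp:>2}: {tp_state_memory_gb(70e9, tp):7.1f} GB per GPU")
```

<p>Under these assumptions, even TP=8 leaves about 140 GB of model states per GPU for a 70B model, more than a single 80 GB accelerator holds, which is why TP alone is rarely enough at this scale.</p>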
| <blockquote> | |
| <p>One interesting note about layer normalization in tensor parallel training - since each TP rank sees the same activations after the all-gather, the layer norm weights don't actually need an all-reduce to sync their gradients after the backward pass. They naturally stay in sync across ranks. However, for dropout operations, we must make sure to sync the random seed across TP ranks to maintain deterministic behavior. | |
| </p> | |
| </blockquote> | |
| <p>This raises an interesting question - could we extend tensor parallelism to these remaining operations as well? Indeed, it's possible to parallelize layer norm, dropout and other operations too, which we'll explore next.</p> | |
| <h2>Sequence Parallelism</h2> | |
<p>In regions where we apply tensor parallelism (TP), like attention and feedforward layers, each GPU only needs to operate on a portion of the hidden dimension since the weights are sharded. However, operations like layer norm or dropout (which is rarely used in modern LLMs) require access to the full hidden dimension to compute correctly.</p>
| <p>Rather than gathering the full hidden dimension on each GPU (which would defeat the memory benefits of TP), we can instead shard these operations along the sequence length dimension. This approach is called <strong>sequence parallelism (SP)</strong>.</p> | |
| <blockquote> | |
<p>Note that the term Sequence Parallelism is a bit overloaded: the Sequence Parallelism in this section is tightly coupled to Tensor Parallelism and applies to dropout and layer norm operations. However, when we move to longer sequences, the attention computation will become a bottleneck, which calls for techniques such as Ring-Attention. These are sometimes also called <em>Sequence Parallelism</em>, but we’ll refer to them as <em>Context Parallelism</em> to differentiate the two approaches.
So each time you see sequence parallelism, remember that it is used together with tensor parallelism (in contrast to context parallelism, which can be used independently).
| </p> | |
| </blockquote> | |
<p>Sequence parallelism (SP) involves splitting the activations and computations for the parts of the model not handled by tensor parallelism (TP), such as Dropout and LayerNorm, but along the input sequence dimension rather than across the hidden dimension. This is needed because these operations require access to the full hidden dimension to compute correctly. For example, LayerNorm needs the full hidden dimension to compute mean and variance:</p>
<p>$$
\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
$$</p>
<p>where $\mu = \text{mean}(x)$ and $\sigma^2 = \text{var}(x)$ are computed across the hidden dimension $h$.</p>
| <p>So even though these operations are computationally cheap, they still require significant activation memory since they need the complete hidden dimension. SP allows us to shard this <strong>memory</strong> burden across GPUs by splitting along the sequence dimension instead.</p> | |
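<p>The key property is that LayerNorm’s statistics are computed per token over <code>h</code>, so splitting along the sequence dimension does not change the result. A small numpy sketch (the sizes and the <code>layer_norm</code> helper are illustrative assumptions) confirms that normalizing each <code>(b, s/tp, h)</code> chunk independently matches normalizing the full tensor, with no communication at all:</p>

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Statistics over the last (hidden) axis only: each token is normalized on its own.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, as in LayerNorm
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
b, s, h, tp = 2, 8, 16, 4
x = rng.standard_normal((b, s, h))
gamma, beta = rng.standard_normal(h), rng.standard_normal(h)

full = layer_norm(x, gamma, beta)
# SP: each rank holds a (b, s/tp, h) chunk and normalizes it independently.
chunks = [layer_norm(c, gamma, beta) for c in np.split(x, tp, axis=1)]
print(np.allclose(np.concatenate(chunks, axis=1), full))
```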
| <p>In practice we’ll go from the left diagram to the right:</p> | |
| <p><img alt=" in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter | |
| in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather | |
| SP region needs full hidden_dim" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2030.png" /></p> | |
<p>In forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter<br />
In backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather<br />
SP region needs full hidden_dim</p>
| <p>The diagram shows how we transition between tensor-parallel and sequence-parallel regions using different collective operations (labeled "f" and "g"). The key challenge is managing these transitions efficiently while keeping memory usage low and maintaining correctness.</p> | |
| <p>Let's first understand the operations in tensor parallelism (TP):</p> | |
<p>In the forward pass:</p>
<ul>
<li>"f" is a no-op (no operation) because activations are already duplicated across ranks</li>
<li>"f*" is an all-reduce to synchronize activations and ensure correctness</li>
</ul>
<p>In the backward pass:</p>
<ul>
<li>"f*" is a no-op because gradients are already duplicated across ranks</li>
<li>"f" is an all-reduce to synchronize gradients</li>
</ul>
<p>These operations "f" and "f*" are called conjugate pairs because they complement each other: each is a no-op in one pass and an all-reduce in the other, with the roles swapped between the two (f is a no-op in forward and an all-reduce in backward, f* the reverse).</p>
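<p>To see why f needs an all-reduce in the backward pass, consider the column-linear layer that follows it: each rank only sees its own output columns, so its gradient with respect to the shared input is a partial sum. The sketch below simulates the collectives on lists of per-rank numpy arrays; <code>identity</code>, <code>all_reduce</code>, <code>f</code> and <code>f_star</code> are hypothetical stand-ins, not <code>torch.distributed</code> calls.</p>

```python
import numpy as np

# Hypothetical single-process collectives: a "tensor" is a list with one array per rank.
def identity(per_rank):
    return [t.copy() for t in per_rank]

def all_reduce(per_rank):
    total = sum(per_rank)
    return [total.copy() for _ in per_rank]

# Each operator is a (forward, backward) pair; f and f* are mirror images of each other.
f = {"forward": identity, "backward": all_reduce}
f_star = {"forward": all_reduce, "backward": identity}

rng = np.random.default_rng(0)
tp = 2
X = rng.standard_normal((3, 4))
W = rng.standard_normal((4, 6))
dY = rng.standard_normal((3, 6))  # upstream gradient w.r.t. the column-linear output
w_shards = np.split(W, tp, axis=1)
dy_shards = np.split(dY, tp, axis=1)

# Forward through f: a no-op, every rank already holds the same X.
xs = f["forward"]([X] * tp)
y_shards = [x @ w for x, w in zip(xs, w_shards)]

# Backward through the column-linear layer: each rank's gradient w.r.t. X is a partial sum...
dx_partial = [dy @ w.T for dy, w in zip(dy_shards, w_shards)]
# ...and f's backward all-reduce turns it into the full gradient dY @ W.T on every rank.
dX = f["backward"](dx_partial)
```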
| <p>For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/c65c0745-6dda-4f5c-a7ae-0092e50cdc0f.png" /></p> | |
| <p>So what is actually happening here? As a famous LLM would say, let’s take it step-by-step:</p> | |
<ol>
<li><strong>Initial LayerNorm (SP Region)</strong>
<ul>
<li>Input tensors X1* and X2* (b,s/2,h) enter LayerNorm, already split across the sequence dimension</li>
<li>Each GPU computes LayerNorm independently on its sequence chunk, giving Y1* and Y2*</li>
</ul>
</li>
<li><strong>First Transition (SP → TP)</strong>
<ul>
<li>"g" operation (all-gather) combines Y1* and Y2* back to full sequence length</li>
<li>Restores Y (b,s,h), since the column-linear layer takes the full sequence, with the full hidden dimension h, as input on every GPU</li>
</ul>
</li>
<li><strong>First Linear Layer (TP Region)</strong>
<ul>
<li>A1 is a column-linear layer, so each GPU produces its own slice of the output along the hidden dimension</li>
<li>GeLU is applied independently on each GPU</li>
<li>Z1* is (b,s,h/2)</li>
</ul>
</li>
<li><strong>Second Linear Layer (TP Region)</strong>
<ul>
<li>B1 is a row-linear layer, so it restores the hidden dimension</li>
<li>W1 is (b,s,h), but holds only a partial sum on each GPU</li>
</ul>
</li>
<li><strong>Final Transition (TP → SP)</strong>
<ul>
<li>"g*" operation (reduce-scatter) performs the reduction needed for row-linear correctness while scattering along the sequence dimension</li>
<li>W1* is (b,s/2,h)</li>
</ul>
</li>
</ol>
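<p>The five steps above can be simulated end to end in a single process. This is a sketch under stated assumptions: numpy, LayerNorm without its learnable γ/β, no residual connection or dropout, and hypothetical helper names. It checks that the SP/TP pipeline reproduces the plain MLP while each rank only ever holds either a sequence chunk or a hidden-dimension slice.</p>

```python
import numpy as np

def all_gather_seq(chunks):
    # "g" in the forward pass: every rank receives the full sequence
    full = np.concatenate(chunks, axis=1)
    return [full.copy() for _ in chunks]

def reduce_scatter_seq(partials):
    # "g*" in the forward pass: sum the partial outputs, then hand each rank one sequence chunk
    total = sum(partials)
    return np.split(total, len(partials), axis=1)

def layer_norm(x, eps=1e-5):
    # gamma/beta omitted for brevity; statistics are per token over the hidden axis
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def sp_tp_mlp(x_chunks, A, B):
    tp = len(x_chunks)
    a_shards = np.split(A, tp, axis=1)                # column-linear shards of A
    b_shards = np.split(B, tp, axis=0)                # row-linear shards of B
    y_chunks = [layer_norm(x) for x in x_chunks]      # 1. SP region: (b, s/tp, h) per rank
    ys = all_gather_seq(y_chunks)                     # 2. g: (b, s, h) on every rank
    zs = [gelu(y @ a) for y, a in zip(ys, a_shards)]  # 3. TP region: (b, s, ffn/tp) per rank
    ws = [z @ w for z, w in zip(zs, b_shards)]        # 4. row-linear: partial sums, (b, s, h)
    return reduce_scatter_seq(ws)                     # 5. g*: (b, s/tp, h) per rank

rng = np.random.default_rng(0)
b, s, h, ffn, tp = 2, 8, 4, 16, 2
x = rng.standard_normal((b, s, h))
A = rng.standard_normal((h, ffn))
B = rng.standard_normal((ffn, h))
out_chunks = sp_tp_mlp(np.split(x, tp, axis=1), A, B)
```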
| <p>A key advantage of sequence parallelism is that it reduces the maximum activation size we need to store. In tensor parallelism alone, we had to store activations of shape (b,s,h) at various points. However, with sequence parallelism, the maximum activation size is reduced to $\frac{b \cdot s \cdot h}{tp}$ since we always either split along the sequence or hidden dimensions.</p> | |
<p>It’s a bit difficult to keep track of all the parts that are sharded differently in TP and TP/SP (believe us, we find it hard to map as well), so we made this small table to summarize how the shapes of the activations (aka <code>hidden_states</code>) change across hidden dimension h and sequence dimension s during a forward pass:</p>
<table>
<thead>
<tr><th>Region</th><th>Vanilla TP</th><th>TP with SP</th></tr>
</thead>
<tbody>
<tr><td>Enter TP (Column Linear)</td><td>h: sharded (weight_out is sharded)<br />s: full</td><td>h: sharded (weight_out is sharded)<br />s: <strong>all-gather</strong> to full</td></tr>
<tr><td>TP Region</td><td>h: sharded<br />s: full</td><td>h: sharded<br />s: full</td></tr>
<tr><td>Exit TP (Row Linear)</td><td>h: full (weight_out is full + <strong>all-reduce</strong> for correctness)<br />s: full</td><td>h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)<br />s: <strong>reduce-scatter</strong> to sharded</td></tr>
<tr><td>SP Region</td><td>h: full<br />s: full</td><td>h: full<br />s: sharded</td></tr>
</tbody>
</table>
<p>And for the embedding layer:</p>
<table>
<thead>
<tr><th>Region</th><th>Vanilla TP</th><th>TP with SP</th></tr>
</thead>
<tbody>
<tr><td>Embedding Layer (Row Linear sharded on vocab)</td><td>h: full (weight_out is full + <strong>all-reduce</strong> for correctness)<br />s: unchanged</td><td>h: full (weight_out is full + <strong>reduce-scatter</strong> for correctness)<br />s: <strong>reduce-scatter</strong> to sharded</td></tr>
</tbody>
</table>
<p>Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward pass of vanilla TP we had two all-reduces per transformer block, and in SP we have two all-gathers and two reduce-scatters per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into a reduce-scatter followed by an all-gather (see in [TODO: Appendix link]), they’re actually equivalent in terms of communication. The same reasoning holds for the backward pass, as we just use the conjugate of each operation (no-op ↔ allreduce and allgather ↔ reducescatter).</p>
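<p>The decomposition can be checked on simulated ranks. The sketch below assumes numpy and hypothetical single-process collectives; the communication-volume helper assumes the standard bandwidth-optimal ring algorithm, where each rank sends (N−1)/N of the tensor for a reduce-scatter or an all-gather and twice that for an all-reduce.</p>

```python
import numpy as np

def reduce_scatter(per_rank):
    # Sum across ranks, then each rank keeps one 1/N slice of the result
    total = sum(per_rank)
    return np.split(total, len(per_rank))

def all_gather(per_rank_slices):
    # Every rank receives the concatenation of all slices
    full = np.concatenate(per_rank_slices)
    return [full.copy() for _ in per_rank_slices]

def all_reduce(per_rank):
    # All-reduce expressed as reduce-scatter followed by all-gather
    return all_gather(reduce_scatter(per_rank))

def ring_elements_sent(op, M, N):
    """Elements each rank sends for a size-M tensor on N ranks (ring cost model, an assumption)."""
    if op in ("reduce_scatter", "all_gather"):
        return (N - 1) * M / N
    if op == "all_reduce":
        return 2 * (N - 1) * M / N
    raise ValueError(f"unknown op: {op}")

ranks = [np.arange(8.0) * (r + 1) for r in range(4)]
result = all_reduce(ranks)
```

<p>Under this cost model, one all-reduce moves exactly as much data as one reduce-scatter plus one all-gather, which is the sense in which TP and TP/SP are equivalent in communication.</p>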
<p>You can find an example implementation of both column- and row-linear TP in picotron:
| <a href="https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py">https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py</a> </p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2031.png" /></p> | |
| <p>If you’ve been paying close attention, you’ll notice that we’re talking about 4 comms ops <strong>IN EACH LAYER</strong> (2 for Attention and 2 for MLP), as shown here for the MLP region:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2032.png" /></p> | |
<p>Besides the fact that TP requires communication in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).</p>
| <blockquote> | |
<p>Notice how the all-gather is overlapped with “Y A1”: that’s thanks to this trick in nanotron:
<a href="https://github.com/huggingface/nanotron/blob/9055c664c28a3b430b4e53bfcb5a074068c90f2a/src/nanotron/parallel/tensor_parallel/functional.py#L169-L262">https://github.com/huggingface/nanotron/blob/9055c664c28a3b430b4e53bfcb5a074068c90f2a/src/nanotron/parallel/tensor_parallel/functional.py#L169-L262</a>
and you can find more tricks <a href="https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487">here</a>.
| </p> | |
| </blockquote> | |
| <p>TODO: remove, Profiling:</p> | |
| <ul> | |
| <li>TP</li> | |
| </ul> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2033.png" /></p> | |
| <ul> | |
<li>Sequence Parallelism</li>
| </ul> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2034.png" /></p> | |
<p>The all-reduce takes almost twice as long (900 µs) as the reduce-scatter and all-gather (500 µs).</p>
| <p>Let’s compare throughput as we scale TP and TP/SP for a 3B model:</p> | |
| <p><img alt="Impact of combined Tensor and Sequence Parallelism (TP/SP) on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2035.png" /></p> | |
| <p>Impact of combined Tensor and Sequence Parallelism (TP/SP) on model performance and memory utilization: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.</p> | |
| <p>Let’s summarize our observations:</p> | |
| <ul> | |
<li>for both methods, we notice the biggest performance drop when we move from TP=8 to TP=16, because that’s when we move from communicating only within a single node (NVLink) to communicating across nodes (EFA)</li>
<li>the memory savings in activations when using TP with SP help us fit far bigger batches than TP alone</li>
<li>PyTorch’s memory fragmentation makes it hard for us to predict the exact peak reserved memory. For more details, check the memory_viz tool section. [TODO: add link]</li>
| </ul> | |
| <p>TODO (outro): TP can help sharding activs (sometimes on hidden_dim, sometimes on seq_dim) by sharding the big linears across ranks, but what if we want to scale sequence_length, our activs will still blow up in TP region. → Context parallelism</p> | |
| <p><strong>We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.</strong></p> | |
| <p>However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity. </p> | |
| <blockquote> | |
| <p>Note: Since LayerNorms in the SP region operate on different portions of the sequence, their gradients will differ across TP ranks. To ensure the weights stay synchronized, we need to allreduce their gradients during the backward pass, similar to how DP ensures weights stay in sync. This is a small communication overhead since LayerNorm has relatively few parameters.</p> | |
| </blockquote> | |
| <p>We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism! </p> | |
| <h1>Context Parallelism</h1> | |
| <p>With Tensor Parallelism and Sequence Parallelism, we can reduce the memory requirements per GPU significantly as both model weights and activations are distributed across GPUs. However, when training models on longer and longer sequences (e.g. when scaling to 128k or more tokens per sequence) we might still exceed the memory available on a single node, because inside the TP region we still have to process a full sequence length.</p> | |
| <p>Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2036.png" /></p> | |
<p>Can we apply similar ideas to our sequence parallelism approach, but inside the modules where we already apply Tensor Parallelism, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already covered.</p>
| <h2>Introducing Context Parallelism</h2> | |
| <p>The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension but we now apply this splitting along the full model, instead of only the sequence parallel regions of the model. Our focus here will be to reduce the activation memory footprint by splitting the long sequences, complementing parallelism strategies like TP which target the hidden dimension of the model.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2037.png" /></p> | |
| <p>Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just as in data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.</p> | |
<p>There is one important exception though, which is the <strong><em>attention module</em></strong>. In this module, each token needs to access key/value pairs from <strong>all</strong> other sequence tokens or, in the case of causal attention, at least from all previous tokens.</p>
| <p>Because Context Parallelism splits the inputs along the sequence dimension across GPUs, the attention module requires full communication between GPUs to exchange the necessary key/value data.</p> | |
| <blockquote> | |
<p>Note: there is another trick to reduce the attention memory, Flash Attention, which we’ll look at in a bit [TODO link]. However, with Context Parallelism we can scale the memory reduction further simply by increasing the number of ranks.
| </p> | |
| </blockquote> | |
| <p>TODO: talk abt flashattn. we either use flashattn or CP to reduce seqlen memory</p> | |
<p>That sounds very expensive if we do it naively. Is there a way to do this efficiently and fast? Thankfully there is: a core technique to handle this communication of key/value pairs efficiently is called <em>Ring Attention.</em></p>
| <h2>Discovering Ring Attention</h2> | |
<p>In this implementation of attention, each GPU first initiates a communication operation to send its key/value pairs to other GPUs. While waiting for the other GPUs’ data, it computes the attention score for the portion of the data it already has in memory. Ideally, the next key/value pair is received from another GPU before this computation finishes, allowing the GPU to start the next round of computation immediately after it finishes its first computation.</p>
| <p>To illustrate this, let's suppose we have 4 GPUs and an input of 4 tokens. Initially, the input sequence is split evenly along the sequence dimension, so each GPU will have just one token along with its corresponding Q/K/V values. For example, Q1, K1, and V1 represent the query, key, and value of the first token, which are located on the 1st GPU. The attention calculation will take 4 time steps to complete. At each time step, each GPU follows these 3 stages:</p> | |
| <ol> | |
| <li>Send “current keys and values” to the next machine (except during the last time step) (in a non-blocking manner so it starts the following step before this step is finished)</li> | |
| <li>Locally compute the attention score on the “current keys and values” it already has, which typically involves performing $Softmax(\frac{QK^T}{\sqrt{d}}) * V$.</li> | |
| <li>Wait to receive keys and values from the previous GPU and then move to step 1 with “current keys and values” being now the key/values just received from the previous GPU.</li> | |
| </ol> | |
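<p>The three stages can be sketched as a single-process numpy simulation over N hypothetical GPUs. Two assumptions go beyond the text: the sketch uses non-causal attention for brevity, and it merges the per-step results with a running max and running sum (the "online softmax" trick), since each step only sees part of every softmax row. The send/receive of step 1 and 3 is modeled by indexing: at step t, rank r holds the K/V block that started on rank r − t.</p>

```python
import numpy as np

def ring_attention(q_chunks, k_chunks, v_chunks):
    """Single-process simulation of non-causal ring attention over len(q_chunks) GPUs."""
    n = len(q_chunks)
    d = q_chunks[0].shape[-1]
    outs = []
    for rank in range(n):
        q = q_chunks[rank]
        m = np.full(q.shape[0], -np.inf)  # running max of the scores, per query row
        l = np.zeros(q.shape[0])          # running softmax denominator
        acc = np.zeros_like(q)            # running unnormalized output (assumes d_v == d)
        for step in range(n):
            src = (rank - step) % n       # K/V block received after `step` ring hops
            k, v = k_chunks[src], v_chunks[src]
            scores = q @ k.T / np.sqrt(d)
            m_new = np.maximum(m, scores.max(axis=1))
            p = np.exp(scores - m_new[:, None])
            scale = np.exp(m - m_new)     # rescale what was accumulated with the old max
            l = l * scale + p.sum(axis=1)
            acc = acc * scale[:, None] + p @ v
            m = m_new
        outs.append(acc / l[:, None])
    return outs

def full_attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    p = np.exp(scores - scores.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
outs = ring_attention(*(np.split(t, 4) for t in (Q, K, V)))
```

<p>A real implementation would overlap the transfer of the next K/V block with the computation on the current one; the simulation only reproduces the arithmetic, not the timing.</p>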
| <p>The whole process with 4 GPUs is shown in the following animation:</p> | |
| <p><img alt="ring-attention.gif" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/ring-attention.gif" /></p> | |
| <p>With this animation, it’s also immediately clear why the authors chose to call this approach Ring Attention.</p> | |
<p>There is one big problem though: a naive implementation of Ring Attention leads to a strong imbalance between GPUs, stemming from the shape of the causal attention matrix. Let’s take a real look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2038.png" /></p> | |
| <p>The SoftMax is computed row-wise, which means whenever a GPU has received all the tokens of a row it can be computed. We see that GPU1 can immediately compute it as it starts with tokens 1-4 and GPU1 actually doesn’t need to receive any information from any other GPUs. However, GPU2 will need to wait for the second round to also receive 1-4 and thus have all values for tokens 1-8. Also, GPU1 seems to perform much less work than all the other GPUs.</p> | |
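<p>The imbalance can be quantified by counting the unmasked entries each GPU has to compute. In this sketch (the 16-token, 4-GPU setting and the helper name are illustrative assumptions, with 0-based positions), a query at position p attends to p + 1 keys, so the GPU holding the last tokens does far more work than the one holding the first:</p>

```python
def causal_work_per_gpu(assignment):
    """Number of unmasked (query, key) pairs per GPU under a causal mask.

    assignment[g] lists the 0-based token positions whose queries live on GPU g;
    the query at position p attends to keys 0..p, i.e. p + 1 entries.
    """
    return [sum(pos + 1 for pos in tokens) for tokens in assignment]

seq_len, n_gpus = 16, 4
chunk = seq_len // n_gpus
sequential = [list(range(g * chunk, (g + 1) * chunk)) for g in range(n_gpus)]
work = causal_work_per_gpu(sequential)
print(work)
```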
| <p>Let’s see if we can balance our computations better:</p> | |
| <h2>Zig-Zag Ring Attention – A Balanced Compute Implementation</h2> | |
<p>We need a better way to distribute the input sequences. This can be achieved by assigning the tokens to the GPUs not purely sequentially, but by mixing the ordering a bit such that we have a good mix of early and late tokens on each GPU. This approach is called <a href="https://arxiv.org/pdf/2311.09431">Zig-Zag attention</a>, and in this new arrangement, if you count the number of colored squares in the attention mask, you’ll see that the computation is now balanced across all GPUs.</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2039.png" /></p> | |
| <p>At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.</p> | |
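<p>One possible way to build such a mixed ordering, assumed here for illustration and not necessarily the exact scheme of the linked paper, is to cut the sequence into 2×N chunks and give GPU g both chunk g and chunk 2N−1−g, pairing an early chunk with a late one. Counting causal work the same way as before (helpers redefined so the sketch stands alone) shows every GPU ends up with the same load:</p>

```python
def zigzag_assignment(seq_len, n_gpus):
    """GPU g gets chunk g and chunk 2*n_gpus - 1 - g (assumes seq_len divides by 2*n_gpus)."""
    n_chunks = 2 * n_gpus
    size = seq_len // n_chunks
    chunks = [list(range(c * size, (c + 1) * size)) for c in range(n_chunks)]
    return [chunks[g] + chunks[n_chunks - 1 - g] for g in range(n_gpus)]

def causal_work_per_gpu(assignment):
    # Query at 0-based position p attends to p + 1 keys under a causal mask
    return [sum(pos + 1 for pos in tokens) for tokens in assignment]

zigzag = zigzag_assignment(seq_len=16, n_gpus=4)
work = causal_work_per_gpu(zigzag)
print(zigzag, work)
```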
<p>We have two general ways to overlap computation and communication: either by performing a general all-gather that regroups all the KV on each GPU at the same time (in a ZeRO-3 type of way), or by gathering them one-by-one from each GPU to each GPU as needed:</p>
| <p><img alt="Context Parallelism using AllGather implementation" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2040.png" /></p> | |
| <p>Context Parallelism using AllGather implementation</p> | |
| <p><img alt="Context Parallelism using All-to-All implementation" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2041.png" /></p> | |
| <p>Context Parallelism using All-to-All implementation</p> | |
| <p>TODO: add links to megatronlm(AllGather) and deepspeed(All2All) implementations</p> | |
| <h1>Pipeline Parallelism</h1> | |
<p>In the TP section we saw that if we try to scale Tensor parallelism past the number of GPUs per single node (typically 4 or 8), we hit a lower-bandwidth network called “inter-node connection”, which can quite strongly impair our performance. We can see this clearly on, e.g., the all-reduce operation when we perform it across several nodes:</p>
| <p><img alt="Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2042.png" /></p> | |
| <p>Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.</p> | |
<p>Sequence and context parallelism can help for long sequences, but they don’t help much if sequence length is not the root cause of our memory issues but rather the size of the model itself. For large models (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2043.png" /></p> | |
| <p>Pipeline Parallelism is conceptually very simple (we simply spread the layers of our model across GPUs), but the devil lies in implementing it efficiently. Let’s dive into it!</p> | |
| <h2>Splitting layers on various nodes - All forward, all backward</h2> | |
| <p>So, let’s say we simply spread the layers across several devices, e.g. a first GPU takes the first few layers, a second GPU takes the next part of the model, and so on. The forward pass through our model now simply involves sequentially passing the batch of data along the model and thus successively using each compute device.</p> | |
| <p>This gives us a first, direct advantage: the required interconnect bandwidth stays quite low, as we only send moderate-sized activations at a handful of locations along the model depth. This is a big difference compared to, e.g., Tensor Parallelism, where communication happens several times within each layer.</p> | |
| <p>But maybe you are starting to get a glimpse of the troubles to come: “sequentially” and “successively”?!? This doesn’t sound very efficient in the world of parallel computation, especially after our discussion about computation and communication overlap.</p> | |
| <p>Indeed, reader! The main challenge in pipeline parallelism will be to efficiently circumvent the sequential nature of PP, so as to keep our GPUs busy at all times and avoid having one GPU computing while the others are waiting. Here is what our GPU utilization looks like when doing a naive, simple forward and backward pass through the model (the numbers indicate the model layers):</p> | |
| <p><img alt="An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2044.png" /></p> | |
| <p>An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.</p> | |
| <p>The remaining idle time is indicated in grey and usually called the “bubble”, and the sight of it probably breaks your heart after we spent so much time optimizing throughput.</p> | |
| <p>We can quantify how efficient a pipeline setup is by looking at how much time we lose because of the bubble. Let’s say $t_f$ and $t_b$ are the times for the forward and backward pass, respectively, as measured for one microbatch and one stage of the pipeline (a simple assumption is often $t_b \approx 2 \times t_f$, which you can see on the above graph). If we could perfectly parallelize, the ideal total time would be $t_{id}=t_f + t_b$. However, counting on the graph, we see that the pipeline bubble adds an extra time of $t_{pb}=(p-1)*(t_f+t_b)$ (where $p$ is the degree of pipeline parallelism, i.e. the number of GPUs on the above graph), i.e. the time each GPU spends waiting while other GPUs are computing.</p> | |
| <p>We can compute the ratio of the additional bubble time over the ideal time:</p> | |
| <p>$$ | |
| r_{bubble} = \frac{(p-1)*(t_f+t_b)}{t_f+t_b} = p-1 | |
| $$</p> | |
| <p>As we add more stages the bubble time thus increases and the utilization drops.</p> | |
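| <p>As a quick worked example with illustrative numbers, take $p=4$ stages, as in the figure above, and the common assumption $t_b = 2t_f$:</p> | |
| <p>$$ | |
| \begin{aligned} | |
| t_{id} &= t_f + t_b = 3t_f \\ | |
| t_{pb} &= (p-1)*(t_f+t_b) = 3*3t_f = 9t_f \\ | |
| r_{bubble} &= \frac{t_{pb}}{t_{id}} = \frac{9t_f}{3t_f} = 3 | |
| \end{aligned} | |
| $$</p> | |
| <p>A full step thus lasts $t_{id} + t_{pb} = 12t_f$, while each GPU only computes for $3t_f$ of it. That is a utilization of 25%, i.e. $1/p$.</p> | |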
| <p>Thankfully, various pipeline parallelism schemes have been designed to reduce the size of the bubble, which, as this example shows, can be very large in a naive implementation.</p> | |
| <p>Let’s take a first tool out of our toolbox and split our batch into smaller, bite-sized portions (micro-batches) that can be processed in parallel, or almost, as we did before in data parallelism for instance. Now, while the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2045.png" /></p> | |
| <blockquote> | |
| <p>Note: in the previous figure the numbers indicated layers, but in all pipeline parallel plots from now on, including this one, they indicate micro-batches. You can think of each square here as containing several layers, as seen in the previous figure.</p> | |
| </blockquote> | |
| <p>The above schedule is called the <strong>all-forward-all-backward</strong> <strong>(AFAB)</strong> schedule, as we first do all forward passes and only then all backward passes. The advantage is that forward and backward steps are still generally sequential, so the general order of model training is preserved. This makes this option rather simple to implement.</p> | |
| <p>Let’s take a look at the high-level training loop:</p> | |
| <pre><code># `communicate` and `parallel_context` are not defined in this snippet: they stand for | |
| # picotron's pipeline communication helper and process-group state (see the full implementation below). | |
| import torch | |
| import torch.nn.functional as F | |
| def pipeline_parallel_afab(model, data_loader, tensor_shapes, device): | |
|     logging_loss, input_tensors, output_tensors = 0.0, [], [] | |
|     data_iter = iter(data_loader)  # create the iterator once so each step gets a new micro-batch | |
|     # All forward passes | |
|     for _ in range(data_loader.num_local_micro_batches): | |
|         # Receive activations from the previous stage, if any | |
|         input_tensor = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_forward') | |
|         batch = next(data_iter) | |
|         batch["hidden_states"] = input_tensor | |
|         output_tensor = model.forward(batch, device) | |
|         # Send activations to the next stage, if any | |
|         communicate(tensor=output_tensor, operation='send_forward') | |
|         if parallel_context.is_pipeline_last_stage: | |
|             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean') | |
|             logging_loss += output_tensor.item() | |
|         # Keep the inputs/outputs of every micro-batch alive until the backward phase | |
|         input_tensors.append(input_tensor) | |
|         output_tensors.append(output_tensor) | |
|     # All backward passes, in the same (FIFO) order as the forward passes | |
|     for _ in range(data_loader.num_local_micro_batches): | |
|         output_tensor_grad = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_backward') | |
|         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) | |
|         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) | |
|         communicate(tensor=input_tensor_grad, operation='send_backward') | |
|     return logging_loss | |
| </code></pre> | |
| <p>You can find the full implementation of the AFAB pipeline in picotron: <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/pipeline_parallel/pipeline_parallel.py#L54-L83">https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/pipeline_parallel/pipeline_parallel.py#L54-L83</a></p> | |
| <p>Let’s estimate the bubble in this example. The difference with our first example is that the ideal time to process $m$ microbatches is now $t_{id} = m*(t_f+t_b)$:</p> | |
| <p>$$ | |
| r_{bubble} = \frac{(p-1)*(t_f+t_b)}{m*(t_f+t_b)} = \frac{p-1}{m} | |
| $$</p> | |
| <p>As we can see, we can fight some of the inefficiency of pipeline stages by adding more microbatches, which reduces the size of the bubble by a factor of $m$.</p> | |
| <p>However, the memory required to store all the activations is just as annoying as the bubble. We need to keep all of the activations in memory until we reach the backward stage, which leads to a quick memory explosion in these implementations of PP. Can we do better and avoid this memory explosion?</p> | |
| <p>Since the memory explosion is triggered by the activations we store for the backward pass, let’s see if we can start performing the backward pass while we are still performing other forward parts of the computation. This will allow us to drop some of the activations we need for the backward pass as soon as possible.</p> | |
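| <p>To check this formula, here is a small event-driven timing sketch of the AFAB schedule. The helper names are our own. We assume every stage takes exactly $t_f$ and $t_b$ per micro-batch and ignore communication time:</p> | |

```python
def afab_makespan(p: int, m: int, t_f: float, t_b: float) -> float:
    """Time for one all-forward-all-backward (AFAB) step with p stages and m micro-batches.

    Every stage runs F0..F(m-1) then B0..B(m-1), one op at a time. A forward on
    stage s waits for the same micro-batch's forward on stage s-1; a backward on
    stage s waits for the same micro-batch's backward on stage s+1.
    Communication time is ignored.
    """
    fwd_end = [[0.0] * m for _ in range(p)]
    bwd_end = [[0.0] * m for _ in range(p)]
    stage_free = [0.0] * p  # when each stage finishes its previous op
    for j in range(m):  # all forward passes
        for s in range(p):
            ready = fwd_end[s - 1][j] if s > 0 else 0.0
            fwd_end[s][j] = max(ready, stage_free[s]) + t_f
            stage_free[s] = fwd_end[s][j]
    for j in range(m):  # then all backward passes, from the last stage down
        for s in reversed(range(p)):
            ready = bwd_end[s + 1][j] if s < p - 1 else fwd_end[s][j]
            bwd_end[s][j] = max(ready, stage_free[s]) + t_b
            stage_free[s] = bwd_end[s][j]
    return max(stage_free)


def bubble_ratio(p: int, m: int, t_f: float, t_b: float) -> float:
    """Idle (bubble) time divided by the ideal time m * (t_f + t_b)."""
    ideal = m * (t_f + t_b)
    return (afab_makespan(p, m, t_f, t_b) - ideal) / ideal


print(bubble_ratio(4, 1, 1.0, 2.0))  # 3.0   (naive case: p - 1)
print(bubble_ratio(4, 8, 1.0, 2.0))  # 0.375 (= (p - 1) / m = 3 / 8)
```

| <p>The simulated step always lasts $(m+p-1)(t_f+t_b)$, so the idle part stays fixed at $(p-1)(t_f+t_b)$ while the useful part grows with $m$.</p> | |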
| <h2>One-forward-one-backward and Llama 3.1 schemes</h2> | |
| <p>This schedule is called <strong>one-forward-one-backward</strong> <strong>(1F1B)</strong>, as the middle/steady state involves alternately performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2046.png" /></p> | |
| <p>The bubble still has the same size, so our training efficiency is not significantly improved. However, we only need to store activations for $p$ micro-batches instead of $m$, which considerably reduces the activation memory explosion we had in the AFAB schedule. As a consequence, we can add more microbatches, which will then actually reduce the bubble.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2047.png" /></p> | |
| <p>A major complexity of this setup, visible in the above graph, is that forward and backward passes are no longer cleanly consecutive but are performed in parallel across devices. This means we have to schedule the switch from forward to backward passes independently on each device, instead of in a simple and common central training loop as usual.</p> | |
| <p>This is one of the reasons implementing Pipeline Parallelism usually requires rather extensive modifications to training code as well as modeling code.</p> | |
| <p>Here is the corresponding 1F1B training loop, written in the same style as the AFAB loop above:</p> | |
| <pre><code># As above, `communicate`, `bidirectional_communicate` and `parallel_context` stand for picotron helpers. | |
| import torch | |
| import torch.nn.functional as F | |
| def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device): | |
|     # Earlier stages run more warmup forwards: stage r needs (pp_world_size - r - 1) of them | |
|     num_warmup_microbatches = min(parallel_context.pp_world_size - parallel_context.pp_rank - 1, data_loader.num_local_micro_batches) | |
|     num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches | |
|     logging_loss, input_tensors, output_tensors = 0.0, [], [] | |
|     data_iter = iter(data_loader)  # create the iterator once so each step gets a new micro-batch | |
|     def _forward_step(input_tensor): | |
|         nonlocal logging_loss | |
|         batch = next(data_iter) | |
|         batch["hidden_states"] = input_tensor | |
|         output_tensor = model.forward(batch, device) | |
|         if parallel_context.is_pipeline_last_stage: | |
|             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean') | |
|             logging_loss += output_tensor.item() | |
|         return output_tensor | |
|     # Warmup: forward passes only, to fill the pipeline | |
|     for _ in range(num_warmup_microbatches): | |
|         input_tensor = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_forward') | |
|         output_tensor = _forward_step(input_tensor) | |
|         communicate(tensor=output_tensor, operation='send_forward') | |
|         input_tensors.append(input_tensor) | |
|         output_tensors.append(output_tensor) | |
|     if num_microbatches_remaining > 0: | |
|         input_tensor = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_forward') | |
|     # 1F1B steady state: each forward is immediately followed by one backward | |
|     for i in range(num_microbatches_remaining): | |
|         output_tensor = _forward_step(input_tensor) | |
|         # Send our activations downstream and receive a gradient in a single exchange | |
|         output_tensor_grad = bidirectional_communicate('send_fwd_recv_bwd', output_tensor, tensor_shapes, torch.float32, device) | |
|         input_tensors.append(input_tensor) | |
|         output_tensors.append(output_tensor) | |
|         # Backward for the oldest in-flight micro-batch, releasing our references to its activations | |
|         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) | |
|         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) | |
|         if i == num_microbatches_remaining - 1:  # last iteration: no more activations to receive | |
|             input_tensor = None | |
|             communicate(tensor=input_tensor_grad, operation='send_backward') | |
|         else: | |
|             # Send gradients upstream and receive the next micro-batch's activations | |
|             input_tensor = bidirectional_communicate('send_bwd_recv_fwd', input_tensor_grad, tensor_shapes, torch.float32, device) | |
|     # Cooldown: backward passes for the micro-batches left over from warmup | |
|     for _ in range(num_warmup_microbatches): | |
|         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) | |
|         output_tensor_grad = communicate(shapes=tensor_shapes, dtype=torch.float32, operation='recv_backward') | |
|         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) | |
|         communicate(tensor=input_tensor_grad, operation='send_backward') | |
|     return logging_loss | |
| </code></pre> | |
| <p>You can find the full implementation in picotron as well: <a href="https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/pipeline_parallel/pipeline_parallel.py#L85-L145">https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/pipeline_parallel/pipeline_parallel.py#L85-L145</a></p> | |
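| <p>To make the memory argument concrete, the sketch below replays the per-stage order of forward and backward steps for AFAB and for the 1F1B loop above, reusing the same warmup count. It then counts how many micro-batches' activations a stage holds at its peak. The helper name is our own, and timing is ignored:</p> | |

```python
def stored_activations_peak(schedule: str, p: int, m: int, rank: int) -> int:
    """Peak number of micro-batches whose activations one pipeline stage holds at once.

    Only the order of forward (F) and backward (B) steps on stage `rank` is replayed,
    with the same warmup count as the 1F1B loop above; timing is ignored.
    """
    if schedule == "afab":
        ops = ["F"] * m + ["B"] * m
    elif schedule == "1f1b":
        warmup = min(p - rank - 1, m)
        ops = ["F"] * warmup + ["F", "B"] * (m - warmup) + ["B"] * warmup
    else:
        raise ValueError(f"unknown schedule: {schedule}")
    in_flight = peak = 0
    for op in ops:
        in_flight += 1 if op == "F" else -1  # F stores activations, B releases them
        peak = max(peak, in_flight)
    return peak


p, m = 4, 8
print([stored_activations_peak("afab", p, m, r) for r in range(p)])  # [8, 8, 8, 8]
print([stored_activations_peak("1f1b", p, m, r) for r in range(p)])  # [4, 3, 2, 1]
```

| <p>With AFAB every stage holds all $m$ micro-batches. With 1F1B the first stage holds at most $p$, and later stages hold even fewer.</p> | |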
| <p>So reordering the computations a bit helped a lot with the memory pressure from activations. Could we get even better performance with more intricate schedules? Yes!</p> | |
| <h2>Interleaving stages</h2> | |
| <p>This schedule has let us improve memory usage, but it has not done much for the size of the idle bubble. Can we also reduce the time spent in the bubble?</p> | |
| <p>Well it turns out this is possible if we are willing to bring in a few additional communications. Time to talk about “<strong>Interleaved Stages</strong>”.</p> | |
| <p>Up to now we’ve sliced our model naively along the model depth dimension, locating for instance layers 1-4 on the first GPU and layers 5-8 on the second GPU. But there are other ways we could slice our layers, e.g. having odd layers 1, 3, 5, 7 on the first GPU and even layers 2, 4, 6, 8 on the second GPU.</p> | |
| <p>This can be seen in general as a kind of “looping pipeline”, where a micro-batch moves in circles from one GPU to the next as it goes through the forward pass of the model.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2048.png" /></p> | |
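| <p>The two ways of slicing can be written down explicitly. This is a sketch with hypothetical helper names. It assumes the number of layers divides evenly into $p \times v$ chunks, where $v$ is the number of model chunks per GPU, and it uses 1-indexed layers as in the text:</p> | |

```python
def assign_layers(num_layers: int, p: int, v: int) -> dict:
    """Map 1-indexed layers to GPU ids when each of the p GPUs holds v model chunks.

    The model is cut into p * v equal chunks ("virtual stages") and chunk k is
    placed on GPU k % p, so consecutive chunks loop around the GPUs.
    v = 1 is the plain contiguous split used so far.
    """
    if num_layers % (p * v) != 0:
        raise ValueError("num_layers must be divisible by p * v")
    chunk_size = num_layers // (p * v)
    layers_on_gpu = {gpu: [] for gpu in range(p)}
    for layer in range(num_layers):
        chunk = layer // chunk_size
        layers_on_gpu[chunk % p].append(layer + 1)
    return layers_on_gpu


def gpu_path(p: int, v: int) -> list:
    """GPUs visited, in order, by one micro-batch during its forward pass."""
    return [chunk % p for chunk in range(p * v)]


print(assign_layers(8, p=2, v=1))  # {0: [1, 2, 3, 4], 1: [5, 6, 7, 8]}
print(assign_layers(8, p=2, v=4))  # {0: [1, 3, 5, 7], 1: [2, 4, 6, 8]}
print(gpu_path(2, 4))              # [0, 1, 0, 1, 0, 1, 0, 1]
```

| <p>With $v=4$, a micro-batch hops between the two GPUs 7 times per forward pass instead of once. This is where the additional communication of the looping pipeline comes from.</p> | |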
| <p>As a consequence, we see additional communications happening, as the model goes several times through each GPU for the same computation that previously took just one pass. However, each forward and backward pass is divided by a factor of $v$, where $v$ is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes:</p> | |
| <p>$$ | |
| \begin{aligned} | |
| t_{pb} &= \frac{(p-1)*(t_f+t_b)}{v} \\ | |
| r_{bubble} &= \frac{1}{v}\frac{(p-1)*(t_f+t_b)}{m*(t_f+t_b)} = \frac{p-1}{v*m} | |
| \end{aligned} | |
| $$</p> | |
| <p>So we can now decrease the bubble by adding microbatches and interleaved stages, but note that, quantitatively, the amount of communication also increases by a factor of $v$, so it’s a trade-off. The following plot shows several configurations for a PP setup with $p=8$. The special case $m=1, v=1$ corresponds to naive pipeline parallelism, the configurations with $v=1$ are AFAB or 1F1B setups, and those with $v \neq 1$ are interleaved configurations.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2049.png" /></p> | |
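| <p>Plugging a few values into the formula makes the trade-off tangible. The sketch below uses hypothetical helper names, and the $(m, v)$ pairs are our own picks rather than necessarily those in the plot. It evaluates $r_{bubble}$ for $p=8$. As a proxy for the extra communication, it also reports the number of point-to-point activation sends per micro-batch forward pass relative to $v=1$, which grows roughly like $v$:</p> | |

```python
def interleaved_bubble_ratio(p: int, m: int, v: int = 1) -> float:
    """r_bubble = (p - 1) / (v * m); v = 1 gives AFAB/1F1B, m = v = 1 the naive case."""
    return (p - 1) / (v * m)


def p2p_sends_factor(p: int, v: int) -> float:
    """Activation sends per micro-batch forward pass, relative to v = 1.

    With v chunks per GPU a micro-batch crosses p * v - 1 chunk boundaries
    instead of p - 1, i.e. roughly v times more point-to-point traffic.
    """
    return (p * v - 1) / (p - 1)


p = 8
for m, v in [(1, 1), (8, 1), (32, 1), (8, 2), (8, 4)]:
    print(f"m={m:>2} v={v}: bubble ratio {interleaved_bubble_ratio(p, m, v):.3f}, "
          f"sends x{p2p_sends_factor(p, v):.2f}")
```
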
| <p>Scheduling also becomes more complex here. On each GPU we need to decide, at any given moment, between two priorities. We can prioritize earlier micro-batches and close their forward and backward loops as fast as possible (so-called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible). Or we can first complete the forward passes of all microbatches in the queue before moving on to backward passes (so-called “breadth-first”, i.e. prioritizing filling the pipeline as much as possible). This is explained in detail in <a href="https://arxiv.org/pdf/2211.05953">https://arxiv.org/abs/2211.05953</a>.</p> | |
| <p>You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a one-forward-one-backward setup with interleaved stages and a priority setting tunable between depth-first and breadth-first.</p> | |
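| <p>To make the two priorities concrete, the following sketch lists the order in which a single GPU would run its forward steps, as (model chunk, micro-batch) pairs. It is a deliberately simplified illustration with hypothetical names: backward passes and timing are left out. Processing micro-batches in groups of $p$ in the depth-first case is our own assumption, not a detail taken from the paper or from Llama 3.1:</p> | |

```python
def forward_order(m: int, v: int, p: int, priority: str) -> list:
    """Order in which one GPU runs forward steps, as (chunk, micro_batch) pairs.

    breadth-first: push every micro-batch through chunk 0, then chunk 1, ...
    depth-first:   take micro-batches in groups of p and push each group through
                   all v chunks before starting the next group.
    """
    if priority == "breadth-first":
        return [(c, mb) for c in range(v) for mb in range(m)]
    if priority == "depth-first":
        order = []
        for start in range(0, m, p):
            group = range(start, min(start + p, m))
            order += [(c, mb) for c in range(v) for mb in group]
        return order
    raise ValueError(f"unknown priority: {priority}")


print(forward_order(4, v=2, p=2, priority="breadth-first"))
print(forward_order(4, v=2, p=2, priority="depth-first"))
```

| <p>In the depth-first order, micro-batches 0 and 1 have gone through both chunks after four steps and can start their backward passes early. In the breadth-first order, every micro-batch first fills chunk 0.</p> | |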
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2050.png" /></p> | |
| <p>However, we haven’t reached the end of possible pipeline schedules, and recently some methods have been proposed to reduce the bubble to virtually zero! Piqued your curiosity? Let’s have a look!</p> | |
| <h2>Zero Bubble and DualPipe</h2> | |
| <p>There are even more sophisticated ways to reduce the bubble further and reach close to a “zero bubble” regime. The secret here is to split the operations involved at an even finer-grained level, in order to interleave them in the most efficient way. For instance, the pipeline implementation in DeepSeek V3/R1, called DualPipe, reaches close to a zero bubble regime.</p> | |
| <p>Let’s very quickly see how this can work by briefly detailing the <a href="https://arxiv.org/abs/2401.10241">ZeroBubble</a> work, which is a precursor to DualPipe. The base observation of ZeroBubble is that the backward pass through a matrix multiplication actually involves two separate operations: the backward for the inputs (B) and the backward for the weights (W):</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2051.png" /></p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2052.png" /></p> | |
| <p>The output of B, the backward pass for the inputs, is necessary for performing the backward pass of the lower layers. The backward pass for the weights, W, is not necessary for the rest of the backward pass and generally only needs to be performed before the optimizer step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage, which allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of a (theoretical) schedule with zero bubble that takes advantage of this fine-grained decomposition.</p> | |
| <p>DeepSeek’s DualPipe extends this decomposed approach to two streams propagating from both ends of the PP ranks. These streams are interleaved to reduce GPU idle time even further, as shown in the following scheduling graph:</p> | |
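| <p>For a single linear layer $y = xW$, the two halves of the backward pass are two independent matrix multiplications. B computes the input gradient $dx = dy\,W^\top$, and W computes the weight gradient $dW = x^\top dy$. The toy pure-Python sketch below uses hypothetical helper names and lists instead of tensors. It computes B immediately and queues W for later. It only illustrates the decomposition, not the ZeroBubble scheduling itself:</p> | |

```python
def matmul(a, b):
    """Naive matrix product of two lists of lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


# Linear layer y = x @ weight, with x: (tokens, d_in) and weight: (d_in, d_out).
def backward_input(dy, weight):
    """B: gradient w.r.t. the input, needed right away by the previous layer/stage."""
    return matmul(dy, transpose(weight))


def backward_weight(x, dy):
    """W: gradient w.r.t. the weight, only needed before the optimizer step."""
    return matmul(transpose(x), dy)


x = [[1.0, 2.0]]
weight = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]
dy = [[1.0, 1.0, 1.0]]  # gradient arriving from the next layer/stage

dx = backward_input(dy, weight)              # B: computed and sent upstream immediately
deferred = [lambda: backward_weight(x, dy)]  # W: queued instead of computed now
# ... a scheduler could run other micro-batches' F/B steps here, e.g. inside a bubble ...
dw = deferred.pop()()                        # W: run later, before the optimizer step
print(dx)  # [[3.0, 2.0]]
print(dw)  # [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]
```
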
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2053.png" /></p> | |
| <p>The ZeroBubble and DualPipe schedules are a bit too complex for us to give code snippets here, but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurements of the time for each operation, followed by a scheduling algorithm able to find the optimal allocation of time given the constraints. See for instance the <a href="https://arxiv.org/abs/2401.10241">ZeroBubble</a> paper for a discussion of the heuristics and algorithms used to perform such scheduling.</p> | |
| <h1>Expert parallelism</h1> | |
| <p>One more <s>thing</s> parallelism.</p> | |
| <p>Mixture-of-experts (MoE) models have gained some traction with models such as Mixtral or, more recently, DeepSeek-V3/R1! The basic idea is that instead of having a single feedforward module per layer, we can have several and route tokens through different ones depending on their context.</p> | |
| <p>So whereas Context parallelism</p> | |
| <p><img alt="https://arxiv.org/pdf/2407.06204" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2054.png" /></p> | |
| <p><a href="https://arxiv.org/pdf/2407.06204">https://arxiv.org/pdf/2407.06204</a></p> | |
| <p>This design makes it very easy to add a new parallelism paradigm: Expert parallelism (EP). Since the feedforward layers are fully independent, we can simply put each expert’s feedforward layer on a different worker. Compared to TP, it’s much more lightweight, since we don’t need to split the matrix multiplication; we just need to route the hidden states of a token to the right expert. There are several tricks to make EP work in practice, closely tied to model design. For instance, DeepSeek-V3 enforces a constraint in the router, ensuring that each token is sent to at most M nodes (in their case, 4) to reduce communication overhead.</p> | |
| <p>Congratulations, reader! We’ve now covered all five current directions of parallelism:</p> | |
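| <p>The routing step that makes EP possible can be sketched in a few lines. This is a minimal pure-Python illustration. The function names and the contiguous placement of experts on devices are our own assumptions. It picks the top-$k$ experts per token and groups tokens by the device hosting each chosen expert, which is what an EP implementation would typically exchange in an all-to-all. Capacity limits and DeepSeek-V3’s node-limited routing are not modeled:</p> | |

```python
from collections import defaultdict


def route_top_k(scores, k):
    """For each token, return the ids of its k highest-scoring experts (ties -> lower id)."""
    return [sorted(range(len(s)), key=lambda e: (-s[e], e))[:k] for s in scores]


def dispatch(routes, experts_per_device):
    """Group (token, expert) pairs by the device hosting the expert.

    Experts are placed contiguously: expert e lives on device e // experts_per_device.
    Each device's list is what it would receive during the dispatch exchange.
    """
    per_device = defaultdict(list)
    for token, experts in enumerate(routes):
        for e in experts:
            per_device[e // experts_per_device].append((token, e))
    return dict(per_device)


# Router scores for 3 tokens over 4 experts, with 2 experts per device (2 devices).
scores = [
    [0.1, 0.7, 0.2, 0.0],
    [0.5, 0.1, 0.1, 0.3],
    [0.2, 0.2, 0.5, 0.1],
]
routes = route_top_k(scores, k=2)
print(routes)  # [[1, 2], [0, 3], [2, 0]]
print(dispatch(routes, experts_per_device=2))
# {0: [(0, 1), (1, 0), (2, 0)], 1: [(0, 2), (1, 3), (2, 2)]}
```
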
| <ol> | |
| <li>Data Parallelism - along the batch dimension, including ZeRO</li> | |
| <li>Tensor Parallelism - along the hidden-state dimension</li> | |
| <li>Sequence and Context Parallelism - along the sequence dimension</li> | |
| <li>Pipeline Parallelism - along the model layers</li> | |
| <li>Expert Parallelism - along the model experts</li> | |
| </ol> | |
| <p>One aspect you may be curious about right now is how ZeRO-3 and TP/PP (which both shard model weights) compare to each other. Let’s look at their similarities and interplay!</p> | |
| <h1>5D parallelism in a nutshell</h1> | |
| <p>Let’s start with Pipeline parallelism as ZeRO-3 and Pipeline parallelism have interesting similarities and differences. </p> | |
| <p>Both methods are ways to partition the model weights over several GPUs and perform communication/computation generally “along the model depth axis”. In the following, we say “a layer” to simplify what should in general be called “a set of layers” (as the basic sharding unit of the model). This means that in both cases full layers are computed on each device, as opposed to TP, where the layers are sharded for the computation.</p> | |
| <p>However, there are a few major differences between the two:</p> | |
| <table> | |
| <thead> | |
| <tr><th>ZeRO-3</th><th>Pipeline parallel</th></tr> | |
| </thead> | |
| <tbody> | |
| <tr><td>each compute unit only stores a fraction of a layer</td><td>each compute unit stores a full layer</td></tr> | |
| <tr><td>communication is used to transfer weights</td><td>communication is used to transfer activations</td></tr> | |
| <tr><td>model-agnostic orchestration</td><td>complex orchestration to maintain PP efficiency</td></tr> | |
| </tbody> | |
| </table> | |
| <p>Clearly, ZeRO-3 and PP are distinctly different approaches to sharding the model layers, focusing communication either on weights or on activations. They can be combined, but this is generally not very interesting, as they are two similar ways to fix the same core symptom: parameter memory. If combined, ZeRO-3 should be configured to keep the gathered weights in memory while processing the PP micro-batches, rather than re-gathering them for every micro-batch, to at least avoid too much communication overhead.</p> | |
| <p>Note that ZeRO-1 and ZeRO-2, on the other hand, are interesting to combine with Pipeline Parallelism, as they focus on gradients and optimizer states instead of parameters and are thus complementary. For instance, DeepSeek-V3 used PP with ZeRO-1!</p> | |
| <p>In contrast to Pipeline Parallelism, Tensor Parallelism is naturally complementary and interoperable with ZeRO-3. For instance, if a submodule of the model (e.g. a layer or a block of layers) is too large to fit on a GPU when rematerialized by ZeRO-3, the only obvious option is to perform partial local operations with Tensor Parallelism for this block, combined with other dimensions of parallelism like ZeRO-3 or PP as we saw above.</p> | |
| <p>Combining ZeRO-3 and TP doesn’t raise any specific issues, except how to organize the GPUs in groups for each parallelism dimension. As detailed above, TP will typically be kept for high-speed intra-node communications. ZeRO-3 can use parallelism groups spanning lower-speed inter-node communications, because its communication is easy to overlap with computation.</p> | |
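| <p>One common way to organize the groups is to make TP the fastest-varying rank dimension, so that each TP group lands on consecutive GPUs of one node. The sketch below builds TP, PP and DP groups from global ranks, and the DP groups are the ones ZeRO-3 would shard over. The function name is hypothetical, CP and EP axes are omitted, and real frameworks may order the axes differently:</p> | |

```python
def parallel_groups(world_size: int, tp: int, pp: int) -> dict:
    """Split global ranks into TP, PP and DP groups, with TP as the fastest-varying axis.

    rank = dp_idx * (pp * tp) + pp_idx * tp + tp_idx, so each TP group is a run of
    consecutive ranks, i.e. GPUs of the same node whenever tp divides the GPUs per node.
    """
    if world_size % (tp * pp) != 0:
        raise ValueError("world_size must be divisible by tp * pp")
    dp = world_size // (tp * pp)

    def to_rank(d: int, q: int, t: int) -> int:
        return d * pp * tp + q * tp + t

    return {
        "tp": [[to_rank(d, q, t) for t in range(tp)] for d in range(dp) for q in range(pp)],
        "pp": [[to_rank(d, q, t) for q in range(pp)] for d in range(dp) for t in range(tp)],
        "dp": [[to_rank(d, q, t) for d in range(dp)] for q in range(pp) for t in range(tp)],
    }


groups = parallel_groups(world_size=16, tp=4, pp=2)  # e.g. 2 nodes of 8 GPUs -> dp = 2
print(groups["tp"])      # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
print(groups["dp"][:2])  # [[0, 8], [1, 9]]: DP/ZeRO-3 groups span both nodes
```
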
| <h1>How to Find the Best Training Configuration</h1> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2055.png" /></p> | |
| <p>We’ve now covered all the parallelism techniques that are actually used to distribute and train larger models. A general question remains: which ones should we choose, and which ones are best combined? We touched a little bit on this at the end of the last section, but in this section we will walk through the decision process step by step.</p> | |
| <table> | |
| <thead> | |
| <tr><th>Method</th><th>Memory savings</th><th>Parallel/sharding dimension</th><th>Disadvantage</th></tr> | |
| </thead> | |
| <tbody> | |
| <tr><td>DP</td><td>Activations</td><td>Batch</td><td>Limited by max batch size</td></tr> | |
| <tr><td>TP/SP</td><td>Intra-layer activations and weights</td><td>Hidden dimension / Sequence length</td><td>Requires high-bandwidth communication</td></tr> | |
| <tr><td>PP</td><td>Inter-layer activations and weights</td><td>Model layers</td><td>Idle bubble and complex schedules</td></tr> | |
| <tr><td>CP</td><td>Attention activations</td><td>Sequence length</td><td>Additional communication</td></tr> | |
| <tr><td>ZeRO-1</td><td>Sharding the optimizer states</td><td>Optimizer states</td><td></td></tr> | |
| <tr><td>ZeRO-2</td><td>Sharding the optimizer states and gradients</td><td>Optimizer states and gradients</td><td></td></tr> | |
| <tr><td>ZeRO-3</td><td>Sharding the optimizer states, gradients, and model parameters</td><td>Optimizer states, gradients, and model weights</td><td></td></tr> | |
| </tbody> | |
| </table> | |
| <p>So why do most pretraining runs rely on 4D parallelism instead of ZeRO, especially since ZeRO is model-agnostic and thus easier to integrate? The reason is its tight coupling to DP: as mentioned earlier, there is an upper limit on the global batch size, which is usually in the 4M-token region. Let’s say you have 1024 GPUs and want to train with a sequence length of 8192. A 4M-token batch then holds only about 512 sequences, so pure DP would mean a batch size of 0.5 sequences per GPU. That is not feasible, since DP requires at least 1. Additionally, even a batch size of 1 means we’ll have to communicate after each sample, which puts a strain on the network bandwidth.</p> | |
| <p>On the other hand, by using PP or TP we can reduce the DP degree significantly, increase the local batch size, and thus reduce the necessary global communication. In general, it is natural to combine 4D parallelism with at least ZeRO-1/2 to save optimizer and gradient memory and keep the number of PP stages or TP ranks under control.</p> | |
| <p>Overall, most training setups past a certain model size will tend to combine several of these methods.</p> | |
| <p>Let’s try to synthesize the decision process into a relatively simple tree structure:</p> | |
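| <p>The batch-size arithmetic above can be written out in a few lines. This sketch uses a hypothetical function name. It takes the “4M tokens” limit as $2^{22}$ tokens and a sequence length of 8192 so that the numbers come out round, and it uses TP=8, PP=2 as one illustrative split:</p> | |

```python
def sequences_per_replica(global_batch_tokens: int, seq_len: int, num_gpus: int,
                          tp: int = 1, pp: int = 1) -> float:
    """Local batch size, in sequences, for each data-parallel replica.

    Only the DP dimension splits the batch: the tp * pp GPUs that jointly hold
    one model replica all process the same sequences.
    """
    if num_gpus % (tp * pp) != 0:
        raise ValueError("num_gpus must be divisible by tp * pp")
    dp = num_gpus // (tp * pp)
    return global_batch_tokens / seq_len / dp


GLOBAL_BATCH_TOKENS = 4 * 1024 * 1024  # "4M tokens", taken as 2**22 for round numbers

print(sequences_per_replica(GLOBAL_BATCH_TOKENS, 8192, 1024))              # 0.5 (pure DP, dp = 1024)
print(sequences_per_replica(GLOBAL_BATCH_TOKENS, 8192, 1024, tp=8, pp=2))  # 8.0 (dp = 64)
```
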
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2056.png" /></p> | |
| <p>To explain briefly: data parallelism is the most efficient method, and you should always prioritize it when memory is not a concern. If communication is not a concern and you can keep the batch size per GPU large enough to make good use of the GPU’s matrix multiplications, ZeRO is an easy method to remove memory bottlenecks and stay close to a simple DP implementation. However, on larger clusters you’ll probably be able to make efficient use of more 4D parallelism. In this case, starting with tensor parallelism is the most direct way to reduce memory usage, and it is generally faster than pipeline parallelism within a single node (8 GPUs). However, in scenarios with long contexts, most of the memory usage will tend to shift from model weights, gradients, and optimizer states to activation values. In such cases, context parallelism becomes more beneficial than pipeline parallelism. Note that this is not an exact recipe, and you should think of it more as a starting point for the hyperparameters of your own benchmarks. For instance, TP mixed with PP can sometimes be more efficient even with a TP degree below 8, and ZeRO-1/2 can make sense to mix in with 4D parallelism as well.</p> | |
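| <p>The same reasoning can be written as a small function. Treat it as a rough sketch of the prose above and use it only as a first guess before benchmarking. The parameter names are ours, and it does not claim to reproduce every branch of the figure:</p> | |

```python
def starting_point(fits_with_dp: bool, zero_is_enough: bool, long_context: bool,
                   gpus_per_node: int = 8) -> list:
    """A rough first guess following the decision process described above.

    fits_with_dp:   the model and its activations fit with plain data parallelism.
    zero_is_enough: ZeRO removes the memory bottleneck while communication stays
                    affordable and the batch size per GPU stays large.
    long_context:   activations, not weights/gradients/optimizer states, dominate memory.
    """
    if fits_with_dp:
        return ["DP"]
    if zero_is_enough:
        return ["DP", "ZeRO"]
    plan = ["DP", f"TP within a node (up to {gpus_per_node} GPUs)"]
    plan.append("CP" if long_context else "PP")  # long contexts favor CP over PP
    plan.append("optionally ZeRO-1/2")
    return plan


print(starting_point(fits_with_dp=False, zero_is_enough=False, long_context=True))
```
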
| <p>This concludes our very deep dive into the distribution methods of 4D parallelism and ZeRO. However, besides scaling our model efficiently across GPUs there is another way to improve model throughput and memory management. </p> | |
| <p>Time to turn the lights off and activate CUDA mode! </p> | |
| <h1>Diving into the GPUs – fusing, threading, mixing</h1> | |
| <p>Up to now our discussion has been focused on the high-level organization of our model operations. We’ve moved around computations on various accelerators, taking into account general memory constraints and high-level scheduling of the compute units.</p> | |
| <p>But this ignores all the optimizations we can do at a much lower level by carefully understanding how our model operations are scheduled and performed on each GPU.</p> | |
| <p>This section will go into much more detail on GPU architecture, in particular NVIDIA’s GPU architecture, but the general ideas, as often, can be reused on similar accelerator units.</p> | |
| <p>We’ll briefly explain how GPUs are organized before covering the Flash-Attention revolution and how to efficiently schedule workloads on GPUs. Finally, we’ll explain how various precisions can be efficiently used on GPUs.</p> | |
| <h2>A primer on GPUs</h2> | |
| <p>Generally, GPUs have a very hierarchical organization. In this primer we’ll keep the discussion at the conceptual level necessary for the rest of our presentation.</p> | |
| <p>On the compute side, GPUs consist of an array of compute units called <strong>Streaming Multiprocessors</strong> (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see <a href="https://resources.nvidia.com/en-us-tensor-core">https://resources.nvidia.com/en-us-tensor-core</a> for details), each capable of handling multiple threads simultaneously.</p> | |
| <p><img alt="Original figure from https://blog.codingconfessions.com/p/gpu-computing." src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2057.png" /></p> | |
| <p>Original figure from <a href="https://blog.codingconfessions.com/p/gpu-computing">https://blog.codingconfessions.com/p/gpu-computing</a>.</p> | |
| <p>The memory side is also highly hierarchical, with several layers of cache and memory. <strong>Registers</strong> are the smallest units and are private to the threads during execution. <strong>Shared Memory</strong> and the <strong>L1 cache</strong> are shared between the threads running on a single SM. Higher up is the <strong>L2 cache</strong>, shared by all SMs. Finally, there is the <strong>Global Memory</strong>, which is the largest memory on the GPU (the advertised 80 GB for an H100, for instance) but also the slowest to access and query.</p> | |
| <p><img alt="Original figure from https://www.youtube.com/watch?v=ZQKMZIP3Fzg" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2058.png" /></p> | |
| <p>Original figure from <a href="https://www.youtube.com/watch?v=ZQKMZIP3Fzg">https://www.youtube.com/watch?v=ZQKMZIP3Fzg</a></p> | |
| <p>The goal when programming a GPU is to run as many workloads as possible, in parallel, on the GPU cores, by taking advantage of this hierarchical organization of compute and memory.</p> | |
| <p>A piece of code running on the GPU cores is called a <strong>kernel</strong>. It can be written at a high level in <strong>CUDA</strong> or <strong>Triton</strong>, for instance, and is then compiled to Parallel Thread Execution (PTX), the low-level assembly used by NVIDIA GPUs.</p> | |
| <p>To run the kernel, you will also need a specific piece of code (called <strong>host code</strong>) which is executed on the <strong>CPU</strong>/host and takes care of preparing data allocations and loading data and code.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2059.png" /></p> | |
| <p>Figure 5: Host code for a CUDA kernel for adding two vectors from <a href="https://blog.codingconfessions.com/p/gpu-computing">https://blog.codingconfessions.com/p/gpu-computing</a></p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2060.png" /></p> | |
| <p>Figure 6: Device code containing the definition of the vector addition kernel from <a href="https://blog.codingconfessions.com/p/gpu-computing">https://blog.codingconfessions.com/p/gpu-computing</a></p> | |
<p>Kernels are generally scheduled as follows:</p>
| <ul> | |
| <li>First the <strong>host code</strong> (executed on the <strong>CPU</strong>/host) takes care of preparing all data allocations and loading data and kernels</li> | |
<li>When ready, the kernel is launched and its threads are scheduled on the SMs following a hierarchical organization as well:<ul>
<li>threads are grouped in <strong>warps</strong> of 32 threads. All the threads in a warp are synchronized to execute the same instruction simultaneously, but on different parts of the data.</li>
<li><strong>warps</strong> are grouped into larger <strong>blocks</strong> of more flexible size (e.g. 256 threads), each block being assigned to a single SM. An SM may run several blocks in parallel; however, depending on the available resources, not all blocks can be assigned for execution immediately, and some are waitlisted until resources free up.</li>
| </ul> | |
| </li> | |
| </ul> | |
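<p>The warp/block hierarchy boils down to a little integer arithmetic. The plain-Python sketch below assumes 1D blocks (such as the 256-thread example above) and the warp size of 32; it computes which warp and which lane within that warp a thread lands in, and how many warps a block occupies. For 2D blocks it uses CUDA's rule that threads are linearized with <code>threadIdx.x</code> varying fastest. The function names are ours.</p>

```python
WARP_SIZE = 32  # threads per warp on NVIDIA GPUs


def warp_layout(thread_idx, block_dim):
    """Return (warp_id, lane_id, num_warps) for a thread in a 1D block."""
    # Threads with consecutive indices are packed into the same warp
    warp_id = thread_idx // WARP_SIZE
    lane_id = thread_idx % WARP_SIZE
    # A partially filled last warp still occupies a full warp slot
    num_warps = (block_dim + WARP_SIZE - 1) // WARP_SIZE
    return warp_id, lane_id, num_warps


def linear_thread_id(tx, ty, block_dim_x):
    """For 2D blocks, CUDA linearizes threads with x varying fastest."""
    return tx + ty * block_dim_x


print(warp_layout(0, 256))    # (0, 0, 8)
print(warp_layout(33, 256))   # (1, 1, 8)
print(warp_layout(255, 256))  # (7, 31, 8)
print(warp_layout(0, 100))    # (0, 0, 4): 100 threads still use 4 warps
```
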
<p>The main thing to remember from these details is that there are various sizing and allocation constraints (size of the various memories, number of concurrent blocks and threads in the warps) which need to be taken into account to use the GPU architecture in the most efficient way.</p>
| <p>Most of the time you don’t need to go down to this level of precision and you can luckily reuse the kernels and code prepared by other members of the community. But in any case we want to give you a primer on how to get started with kernels! </p> | |
<h2>How to improve performance with Kernels?</h2>
<p>If you’re looking to add a new operation that lacks an optimized kernel or to speed up an existing PyTorch function, writing kernels from scratch might seem like the most direct route. However, creating high-performance CUDA kernels from scratch requires extensive experience and comes with a steep learning curve. Generally, a better way to get started is to leverage <code>torch.compile</code>, which dynamically optimizes PyTorch code by capturing your operations and generating lower-level, high-performance kernels in Triton.</p>
<p>Let’s suppose you want to write a kernel for an activation function called Exponential Linear Unit (ELU):</p>
<p>$$
\text{ELU}(x) =
\begin{cases}
\alpha \left(e^x - 1\right) & \text{if } x < 0 \\
x & \text{if } x \geq 0
\end{cases}
$$</p>
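<p>Before optimizing anything, it helps to have a scalar reference to test kernels against. A minimal sketch in plain Python follows, with <code>alpha</code> defaulting to 1.0 as in the PyTorch code below; the name <code>elu_reference</code> is ours.</p>

```python
import math


def elu_reference(x, alpha=1.0):
    """Scalar ELU: alpha * (exp(x) - 1) for negative inputs, identity otherwise."""
    if x < 0:
        # math.expm1 computes exp(x) - 1 accurately for x close to 0
        return alpha * math.expm1(x)
    return x


print(elu_reference(2.0))   # 2.0
print(elu_reference(0.0))   # 0.0
print(elu_reference(-1.0))  # about -0.632, i.e. exp(-1) - 1
```

<p>For large negative inputs the output saturates at <code>-alpha</code>, which is an easy property to check on any optimized kernel.</p>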
<p>You can start with a simple PyTorch implementation and then just add the <code>@torch.compile</code> decorator on top:</p>
<pre><code>import torch

@torch.compile
def elu(x, alpha=1.0):
    return torch.where(x < 0, alpha * (torch.exp(x) - 1), x)
</code></pre>
<p>The performance difference between the compiled and non-compiled versions is striking, especially given that we only added a single decorator. This remarkable difference is illustrated in the graph below (N is the number of columns):</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2061.png" /></p> | |
<p>However, if this performance increase is insufficient, you can consider implementing Triton kernels. As a starting point, you can take a look at the Triton kernel generated by <code>@torch.compile</code>. To do so, you simply need to set the environment variable <code>TORCH_LOGS</code> to <code>"output_code"</code>:</p>
<pre><code>export TORCH_LOGS="output_code"
</code></pre>
| <p>Once you run the Python script with the <code>@torch.compile</code> decorator, it will generate and output the corresponding Triton kernel, which, in this case, is:</p> | |
<pre><code>@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 100000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = 0.0
    tmp2 = tmp0 < tmp1
    tmp3 = tl_math.exp(tmp0)
    tmp4 = 1.0
    tmp5 = tmp3 - tmp4
    tmp6 = tl.where(tmp2, tmp5, tmp0)
    tl.store(out_ptr0 + (x0), tmp6, xmask)
</code></pre>
| <p>To enhance readability, we can modify the variable names, add comments, and make slight adjustments, as demonstrated below:</p> | |
<pre><code>import torch
import triton
import triton.language as tl


@triton.jit
def elu_kernel(input_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    # Calculate the starting index for this block
    block_start = tl.program_id(0) * BLOCK_SIZE
    # Create an array of indices for this block
    block_indices = block_start + tl.arange(0, BLOCK_SIZE)[:]
    # Create a mask to ensure only valid indices are processed
    valid_mask = block_indices < num_elements
    # Load input values from the input pointer based on valid indices
    input_values = tl.load(input_ptr + block_indices, mask=valid_mask)
    # Define the ELU parameters
    zero_value = 0.0  # Threshold for ELU activation
    negative_mask = input_values < zero_value
    exp_values = tl.math.exp(input_values)
    # Define the ELU output shift
    one_value = 1.0
    shifted_exp_values = exp_values - one_value

    output_values = tl.where(negative_mask, shifted_exp_values, input_values)
    # Store the computed output values back to the output pointer
    tl.store(output_ptr + block_indices, output_values, mask=valid_mask)


def elu(x: torch.Tensor) -> torch.Tensor:
    # Host-side launcher (a sketch): assumes x is a contiguous CUDA tensor;
    # BLOCK_SIZE=1024 is an arbitrary choice
    output = torch.empty_like(x)
    num_elements = x.numel()
    grid = (triton.cdiv(num_elements, 1024),)  # one program instance per block of elements
    elu_kernel[grid](x, output, num_elements, BLOCK_SIZE=1024)
    return output
</code></pre>
<p>Here, <code>tl.program_id(0)</code> provides a unique block ID, which we use to determine which section of data that block will process. Using this block ID, <code>block_start</code> calculates the starting index for each block’s section, while <code>block_indices</code> specifies the range of indices within that section. A <code>valid_mask</code> ensures that only indices within <code>num_elements</code> are processed, safely loading the data with <code>tl.load</code>. The ELU function is then applied, modifying values based on whether they're negative, and the results are written back to memory with <code>tl.store</code>.</p>
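<p>Triton's programming model differs from CUDA's: each program instance (one value of <code>tl.program_id(0)</code>) handles a whole block of <code>BLOCK_SIZE</code> elements at once, and the mask switches off the out-of-range lanes of the last block. The pure-Python emulation below mirrors <code>elu_kernel</code> step by step on lists, so the indexing and masking can be checked without a GPU. Lists stand in for Triton's block tensors, the grid size is the same ceiling division that <code>triton.cdiv</code> computes, and the function names are ours.</p>

```python
import math


def elu_kernel_emulated(program_id, inputs, outputs, num_elements, block_size):
    """Mimics one Triton program instance of elu_kernel on Python lists."""
    # Starting index of the block handled by this program instance
    block_start = program_id * block_size
    # Equivalent of block_start + tl.arange(0, BLOCK_SIZE)
    block_indices = [block_start + k for k in range(block_size)]
    # Equivalent of block_indices < num_elements
    valid_mask = [idx < num_elements for idx in block_indices]
    for idx, valid in zip(block_indices, valid_mask):
        if not valid:
            continue  # masked lanes neither load nor store
        x = inputs[idx]
        # Same select as tl.where(negative_mask, exp(x) - 1, x)
        outputs[idx] = math.exp(x) - 1.0 if x < 0.0 else x


def elu_launch(inputs, block_size=4):
    num_elements = len(inputs)
    outputs = [0.0] * num_elements
    grid = (num_elements + block_size - 1) // block_size  # like triton.cdiv
    for pid in range(grid):  # program instances run in parallel on a GPU
        elu_kernel_emulated(pid, inputs, outputs, num_elements, block_size)
    return outputs, grid


outs, grid = elu_launch([-2.0, -1.0, 0.0, 1.0, 2.0, 3.0], block_size=4)
print(grid)  # 2 program instances; the second one has 2 masked-out lanes
```
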
<p>When we benchmark this kernel using <code>triton.testing.Benchmark</code>, we get the following performance:</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2062.png" /></p> | |
<p>This standalone kernel demonstrates superior performance at smaller sizes compared to <code>@torch.compile</code>, but this is likely just an artifact of the compilation overhead of <code>torch.compile</code>. In any case, instead of starting from scratch, we can focus on optimizing this generated kernel, saving us time in the process.</p>
<p>However, in Triton we sometimes cannot fully achieve the peak performance of the device, due to limitations in handling shared memory and scheduling within the streaming multiprocessors (SMs). Our access is restricted to blocks, allowing us only to manage the scheduling of blocks across SMs. To gain even more control, we need to implement kernels in CUDA, where we have access to all the underlying components.</p>
| <p>In CUDA, there are various techniques that can be employed to make kernels more efficient; we will present just a few. These include optimizing memory access patterns to reduce latency, using shared memory to store frequently accessed data, and managing thread workloads to minimize idle times. In summary, the tools for writing code to execute instructions on the GPU are:</p> | |
| <ul> | |
<li>PyTorch: easy but slow</li>
<li>torch.compile: easy, fast, but not flexible</li>
<li>Triton: harder, faster, and more flexible</li>
<li>CUDA: hardest, fastest, and most flexible (if you get it right)</li>
| </ul> | |
<p>Let’s talk about one of the most frequent techniques we can use: optimizing memory access. The global memory in GPUs (the largest memory in our graph above) has a long latency and low bandwidth in comparison to the caches, which often creates a major bottleneck for most applications. Efficiently accessing data from global memory can improve performance a lot.</p>
| <h3>Memory Coalescing</h3> | |
| <p>To effectively utilize the bandwidth of global memory, it is essential to understand its architecture. In CUDA devices, global memory is implemented using DRAM.</p> | |
| <p>Memory coalescing takes advantage of how DRAM delivers data in bursts, or ranges of consecutive memory locations, whenever a memory address is accessed. Each time a DRAM location is accessed, a sequence of consecutive locations, including the requested one, is read in parallel by multiple sensors in the DRAM chip. Once read, this data can then be quickly transferred to the processor as a burst. In CUDA, coalescing uses this burst behavior to maximize memory access efficiency by ensuring that threads in a warp—32 threads that execute the same instruction in lockstep (SIMD)—access consecutive memory locations. For instance, if thread 0 accesses location M, thread 1 accesses M + 1, thread 2 accesses M + 2, and so forth, the GPU hardware coalesces or combines these requests into one large, efficient access request for the DRAM burst, rather than handling each access individually. </p> | |
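<p>A simplified way to reason about coalescing is to count how many distinct memory segments a single warp-wide load touches. The plain-Python model below assumes 4-byte <code>float</code> elements and 128-byte aligned segments, which is a simplification of the real transaction rules (these vary across GPU generations). It compares a warp reading consecutive elements with a warp whose threads read elements one row of a 1024-wide row-major matrix apart.</p>

```python
ELEMENT_BYTES = 4    # float32
SEGMENT_BYTES = 128  # assumed segment size for this simplified model
WARP_SIZE = 32


def segments_touched(element_indices):
    """Number of distinct aligned segments covered by one warp's accesses."""
    return len({(idx * ELEMENT_BYTES) // SEGMENT_BYTES for idx in element_indices})


base = 0
# Thread t reads element base + t: consecutive addresses, fully coalesced
contiguous = [base + t for t in range(WARP_SIZE)]
# Thread t reads element base + t * 1024: walking down a column of a 1024-wide matrix
strided = [base + t * 1024 for t in range(WARP_SIZE)]

print(segments_touched(contiguous))  # 1 segment: 32 * 4 bytes = 128 bytes
print(segments_touched(strided))     # 32 segments: one per thread
```

<p>Shifting the contiguous pattern so that it straddles a segment boundary already costs two segments, which is why alignment matters too.</p>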
| <p>Let’s take the example of matrix multiplication. A simple, straightforward implementation would have each thread compute a single element of the output matrix, like this:</p> | |
<pre><code>__global__ void matmul_naive(int M, int N, int K, const float *A, const float *B, float *C) {
    // Each thread computes a single element C[x][y] of the output matrix
    const uint x = blockIdx.x * blockDim.x + threadIdx.x;
    const uint y = blockIdx.y * blockDim.y + threadIdx.y;

    if (x < M && y < N) {
        float tmp = 0.0;
        for (int i = 0; i < K; ++i) {
            tmp += A[x * K + i] * B[i * N + y];
        }
        C[x * N + y] = tmp;
    }
}
</code></pre>
<p>Here’s an excellent visualization of the kernel from this fantastic <a href="https://siboehm.com/articles/22/CUDA-MMM">blogpost</a>:</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2063.png" /></p> | |
| <p>However, when profiling this kernel with a tool like <code>ncu</code>, we can see issues, including low memory throughput and uncoalesced memory accesses.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2064.png" /></p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2065.png" /></p> | |
| <p>The reason for this is that in this kernel, two threads in the same block with Thread IDs <code>(0, 0)</code> and <code>(1, 0)</code> (which will end up in the same warp) will both load from the same column of matrix <code>B</code> but different rows of matrix <code>A</code>. Since matrix elements are stored in row-major order (meaning each row's elements are in consecutive memory addresses, as shown in the figure below), in the first iteration with <code>i = 0</code>, thread <code>(0, 0)</code> will load $A_{0,0}$, and thread <code>(1, 0)</code> will load $A_{1,0}$. These elements are not stored close to each other in memory, and this misalignment repeats across all iterations along the shared dimension, preventing memory accesses from being coalesced.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2066.png" /></p> | |
<p>To improve our kernel, we can change the way the coordinates x and y are calculated, as follows:</p>
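<pre><code>#define BLOCKSIZE 32  // illustrative value: each block computes a BLOCKSIZE x BLOCKSIZE tile of C

// Launched with a 1D block of BLOCKSIZE * BLOCKSIZE threads
__global__ void matmul_coalesced(int M, int N, int K, const float *A, const float *B, float *C) {
    // Consecutive threadIdx.x values share the same row x and get consecutive columns y
    const int x = blockIdx.x * BLOCKSIZE + (threadIdx.x / BLOCKSIZE);
    const int y = blockIdx.y * BLOCKSIZE + (threadIdx.x % BLOCKSIZE);

    if (x < M && y < N) {
        float tmp = 0.0;
        for (int i = 0; i < K; ++i) {
            tmp += A[x * K + i] * B[i * N + y];
        }
        C[x * N + y] = tmp;
    }
}
</code></pre>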
| <p>Instead of using a 2D block, we switch to a 1D block and redefine how we determine the values of <code>x</code> and <code>y</code>. In this new method, threads within the same warp (which have close <code>threadIdx.x</code> values) will share the same <code>x</code> value but have different <code>y</code> values. This means that they will load the same row of matrix <code>A</code> but different columns of matrix <code>B</code>. As a result, memory accesses can be coalesced for a row-major matrix.</p> | |
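<p>The difference between the two kernels can be checked by listing the elements that the 32 threads of the first warp read from <code>A</code> and <code>B</code> in the first loop iteration (<code>i = 0</code>). This plain-Python sketch assumes <code>BLOCKSIZE = 32</code>, square matrices with <code>K = N = 1024</code> and block (0, 0); for the naive 2D launch it relies on CUDA grouping threads into warps with <code>threadIdx.x</code> varying fastest.</p>

```python
K = N = 1024
BLOCKSIZE = 32
WARP_SIZE = 32
i = 0  # first iteration of the inner loop over the shared dimension

# Naive kernel: 2D block, x = threadIdx.x and y = threadIdx.y.
# The first warp holds threadIdx.x = 0..31 with threadIdx.y = 0.
naive = [(tx, 0) for tx in range(WARP_SIZE)]  # (x, y) per thread
naive_a = [x * K + i for x, y in naive]       # element of A read by each thread
naive_b = [i * N + y for x, y in naive]       # element of B read by each thread

# Coalesced kernel: 1D block, x = threadIdx.x / BLOCKSIZE, y = threadIdx.x % BLOCKSIZE
coalesced = [(t // BLOCKSIZE, t % BLOCKSIZE) for t in range(WARP_SIZE)]
coal_a = [x * K + i for x, y in coalesced]
coal_b = [i * N + y for x, y in coalesced]

print(naive_a[:3], len(set(naive_b)))  # [0, 1024, 2048] 1 -> A reads are 1024 elements apart
print(len(set(coal_a)), coal_b[:3])    # 1 [0, 1, 2] -> A is a broadcast, B is contiguous
```

<p>In both kernels one of the two operands is read at a single address by the whole warp (a broadcast); the fix consists in making the other operand's reads contiguous instead of strided.</p>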
| <p>When we profile our new kernel, we notice that the warning about uncoalesced memory accesses has disappeared, and <strong>the GPU's memory throughput has increased by approximately 10 times</strong>.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2067.png" /></p> | |
<p>We also notice that the execution time of the kernel <strong>decreases by 10x</strong>!</p>
<p>Let’s cover another technique you will often see mentioned in the literature: tiling.</p>
| <h3>Tiling</h3> | |
| <p>Tiling is a technique that leverages <em>shared memory</em> to optimize memory access patterns. As we mentioned above, the shared memory is a small, fast memory accessible by all threads within a block. It allows data to be reused by multiple threads, reducing the need to repeatedly load data from slower global memory.</p> | |
| <p>In matrix multiplication for example, each thread in a block may need elements from two matrices, say A and B. If each thread independently loads the row and column it needs from global memory, we end up with many redundant loads, as multiple threads in a block will access overlapping data. Instead, we can use tiling to load a block (or tile) of A and B into shared memory just once, allowing all threads in that block to reuse the same shared data.</p> | |
<p>In the tiling approach, each iteration involves all threads within a block cooperatively loading two tiles, one from matrix A and one from matrix B, into shared memory. Specifically, threads load a tile of matrix A (of size <code>BLOCK_SIZE_M</code> by <code>BLOCK_SIZE_K</code>) and a tile of matrix B (of size <code>BLOCK_SIZE_K</code> by <code>BLOCK_SIZE_N</code>). Once the tiles are in shared memory, the threads perform matrix multiplication on these tiles, enabling efficient computation since all necessary data is quickly accessible. The results of the tile multiplication are stored in an accumulation matrix that holds intermediate results. After each iteration, the results from the current tile multiplication are added to this accumulation matrix, continuing until all tiles from both matrices have been processed.</p>
| <p><img alt="From https://cnugteren.github.io/tutorial/pages/page4.html" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2068.png" /></p> | |
| <p>From <a href="https://cnugteren.github.io/tutorial/pages/page4.html">https://cnugteren.github.io/tutorial/pages/page4.html</a></p> | |
<p>The important parts to understand the implementation are below (for simplicity we consider a square-shaped tile):</p>
<pre><code>#define TILE_SIZE 32  // illustrative tile width

// Launched with TILE_SIZE x TILE_SIZE threads per block; assumes M, N and K are multiples of TILE_SIZE
__global__ void matmul_tiled(int M, int N, int K, const float *A, const float *B, float *C) {
    __shared__ float sharedA[TILE_SIZE * TILE_SIZE];
    __shared__ float sharedB[TILE_SIZE * TILE_SIZE];

    // Output tile handled by this block, and position of this thread inside the tile
    const int blockRow = blockIdx.y;
    const int blockCol = blockIdx.x;
    const int localRow = threadIdx.y;
    const int localCol = threadIdx.x;  // threadIdx.x as column index keeps warp accesses coalesced

    // Set pointers to the starting elements
    A += blockRow * TILE_SIZE * K;                        // Start at row = blockRow, column = 0
    B += blockCol * TILE_SIZE;                            // Start at row = 0, column = blockCol
    C += blockRow * TILE_SIZE * N + blockCol * TILE_SIZE; // Start at row = blockRow, column = blockCol
    float sum = 0.0;
    // The outer loop moves through tiles of A (across columns) and B (down rows)
    for (int tileIdx = 0; tileIdx < K; tileIdx += TILE_SIZE) {
        // Each thread loads one element of A and one element of B into shared memory
        sharedA[localRow * TILE_SIZE + localCol] = A[localRow * K + localCol];
        sharedB[localRow * TILE_SIZE + localCol] = B[localRow * N + localCol];

        // Ensure all threads in the block have completed data loading
        __syncthreads();

        // Shift pointers to the next tile
        A += TILE_SIZE;
        B += TILE_SIZE * N;

        // Compute the partial dot product for this tile
        for (int i = 0; i < TILE_SIZE; ++i) {
            sum += sharedA[localRow * TILE_SIZE + i] * sharedB[i * TILE_SIZE + localCol];
        }
        // Synchronize again to prevent any thread from loading new data
        // into shared memory before others have completed their calculations
        __syncthreads();
    }
    C[localRow * N + localCol] = sum;
}
</code></pre>
| <p>Each thread begins by loading one element from both <strong>Matrix A</strong> and <strong>Matrix B</strong> into shared memory. In this scenario, achieving coalesced memory access is straightforward, by assigning <code>threadIdx.x</code> as the <strong>local column index (localCol)</strong>, threads within the same warp will access adjacent elements of both matrices. After each thread in the block completes loading its elements into shared memory (ensured by calling <code>__syncthreads()</code>), they proceed to compute the dot product of the two tiles. Once the threads have iterated through all the tiles—horizontally for <strong>Matrix A</strong> and vertically for <strong>Matrix B</strong>—the resulting sum is stored in the corresponding location of <strong>Matrix C</strong>.</p> | |
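<p>The payoff of tiling is fewer global memory reads. The plain-Python emulation below follows the same loop structure as the kernel above (one output tile per block, a shared-memory copy of one tile of <code>A</code> and one tile of <code>B</code> per step of the outer loop) and counts global loads, comparing them with the naive kernel where every thread reads a full row of <code>A</code> and a full column of <code>B</code>. It assumes square tiles and matrix sizes that are multiples of the tile size; the function names are ours.</p>

```python
def matmul_naive_count(A, B, M, N, K):
    C = [[0.0] * N for _ in range(M)]
    loads = 0
    for x in range(M):
        for y in range(N):
            tmp = 0.0
            for i in range(K):
                tmp += A[x][i] * B[i][y]
                loads += 2  # one element of A and one of B from global memory
            C[x][y] = tmp
    return C, loads


def matmul_tiled_count(A, B, M, N, K, tile):
    C = [[0.0] * N for _ in range(M)]
    loads = 0
    for block_row in range(M // tile):          # one thread block per output tile
        for block_col in range(N // tile):
            acc = [[0.0] * tile for _ in range(tile)]
            for tile_idx in range(0, K, tile):  # outer loop over tiles of the shared dimension
                # Each thread of the block copies one element of A and one of B into shared memory
                shared_a = [[A[block_row * tile + r][tile_idx + c] for c in range(tile)] for r in range(tile)]
                shared_b = [[B[tile_idx + r][block_col * tile + c] for c in range(tile)] for r in range(tile)]
                loads += 2 * tile * tile
                # Partial dot products only read from the shared-memory tiles
                for r in range(tile):
                    for c in range(tile):
                        acc[r][c] += sum(shared_a[r][k] * shared_b[k][c] for k in range(tile))
            for r in range(tile):
                for c in range(tile):
                    C[block_row * tile + r][block_col * tile + c] = acc[r][c]
    return C, loads


M = N = K = 8
A = [[float(r + c) for c in range(K)] for r in range(M)]
B = [[float(r - c) for c in range(N)] for r in range(K)]
C_naive, naive_loads = matmul_naive_count(A, B, M, N, K)
C_tiled, tiled_loads = matmul_tiled_count(A, B, M, N, K, tile=4)
print(naive_loads, tiled_loads)  # 1024 256: tiling divides global loads by the tile size
```
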
<p>When benchmarking this kernel using ncu, we noticed that the memory throughput increased to 410 GB/s, and the kernel execution time decreased by ~43%, achieving ~6.6 TFLOPS.</p>
| <h3>Thread Coarsening</h3> | |
<p>The tiling technique has significantly improved the performance of our kernel. However, when analyzing the warp states, which quantify how many cycles were spent in each state, we observe the following:</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2069.png" /></p> | |
<p>The meaning of the states can be found in the <a href="https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference">Profiling Guide</a>, specifically in the <strong>Warp Stall Reasons</strong> section. There we can read that:</p>
| <blockquote> | |
| <p>smsp__pcsamp_warps_issue_stalled_mio_throttle : Warp was stalled waiting for the MIO (memory input/output) instruction queue to be not full. This stall reason is high in cases of extreme utilization of the MIO pipelines, which include special math instructions, dynamic branches, as well as shared memory instructions. When caused by shared memory accesses, trying to use fewer but wider loads can reduce pipeline pressure. | |
| </p> | |
| </blockquote> | |
<p>So it seems warps are stalling while waiting for shared memory accesses to return! To resolve this issue we can apply the <strong>Thread Coarsening</strong> technique: by merging several threads into a single coarsened thread, we can significantly reduce shared memory accesses, because each coarsened thread handles multiple output elements, which increases the arithmetic intensity of the kernel.</p>
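<p>Coarsening can be illustrated by counting shared-memory reads in the inner loop of the tiled kernel. In the plain-Python sketch below, a coarsened thread computes <code>coarse</code> adjacent outputs of the same row of <code>C</code> (the coarsening factor and function names are ours): the value of <code>sharedA</code> is read once into a register and reused for all its outputs, so each inner-loop step costs <code>1 + coarse</code> shared-memory reads for <code>coarse</code> multiply-adds, instead of 2 reads per multiply-add. This is one simple way to coarsen; other layouts, such as 2D register blocks, push the ratio further.</p>

```python
def coarsened_row_outputs(shared_a, shared_b, row, col_start, coarse, tile):
    """One coarsened thread: computes `coarse` adjacent outputs of row `row` of a tile."""
    acc = [0.0] * coarse  # per-thread accumulators, kept in registers on a GPU
    smem_reads = 0
    for k in range(tile):
        a_val = shared_a[row][k]  # read once from shared memory...
        smem_reads += 1
        for j in range(coarse):
            # ...and reused for every output this thread owns
            acc[j] += a_val * shared_b[k][col_start + j]
            smem_reads += 1
    return acc, smem_reads


def reads_per_fma(coarse, tile=32):
    shared_a = [[1.0] * tile for _ in range(tile)]
    shared_b = [[1.0] * tile for _ in range(tile)]
    _, reads = coarsened_row_outputs(shared_a, shared_b, 0, 0, coarse, tile)
    return reads / (coarse * tile)  # shared-memory reads per multiply-add


print(reads_per_fma(1))  # 2.0: one A value and one B value per multiply-add
print(reads_per_fma(8))  # 1.125: the A value is amortized over 8 outputs
```
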
| <h3>Minimizing Control Divergence</h3> | |
| <p>A Streaming Multiprocessor (SM) is built to execute all threads in a warp using the Single Instruction, Multiple Data (SIMD) model. This means that at any given moment, one instruction is fetched and executed simultaneously for all threads within the warp. When a warp is executed, the threads within it operate on different segments of the data but follow the same instruction, hence the name Single Instruction, Multiple Data. The primary advantage of SIMD is its efficiency; the control hardware responsible for instruction fetching and dispatching is shared among multiple execution units. This design minimizes the hardware overhead associated with control functions, allowing a greater portion of the hardware to focus on improving arithmetic throughput.</p> | |
| <p>Control divergence occurs when threads within the same warp take different execution paths. For instance, if a conditional statement (like an <code>if</code> statement) leads to some threads executing one block of code while others execute a different block, the warp must serialize these executions, resulting in idle threads waiting for others to complete. </p> | |
| <p>To minimize control divergence, we need to design kernels to ensure that threads within the same warp follow the same execution path. This can be achieved by restructuring code to reduce branching, using data structures that ensure all threads follow similar execution paths, or employing techniques such as predication.</p> | |
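<p>A simple way to see the cost of divergence is to count how many branch paths each warp has to execute: if both outcomes of an <code>if</code> occur among its 32 lanes, the two paths run one after the other with the inactive lanes masked off. The plain-Python model below is deliberately coarse (it ignores the independent thread scheduling of recent architectures and only models a single if/else) and compares a condition on lane parity with one that is uniform within each warp.</p>

```python
WARP_SIZE = 32


def branch_passes(condition, num_threads=64):
    """Total number of branch-path executions across all warps for one if/else."""
    passes = 0
    for warp_start in range(0, num_threads, WARP_SIZE):
        outcomes = {condition(t) for t in range(warp_start, warp_start + WARP_SIZE)}
        # A warp executes every path taken by at least one of its lanes
        passes += len(outcomes)
    return passes


# Divergent: even and odd lanes take different paths inside every warp
print(branch_passes(lambda t: t % 2 == 0))                 # 4: both paths in each of 2 warps
# Warp-uniform: the condition depends only on the warp index
print(branch_passes(lambda t: (t // WARP_SIZE) % 2 == 0))  # 2: one path per warp
```

<p>Predication takes the other route: compute both candidates for every lane and select one per lane, which is what <code>tl.where</code> does in the Triton ELU kernel above.</p>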
| <p>We have covered some of the main considerations when writing custom kernels and improving the performance and memory footprint of GPU operations. We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>Flash Attention.</em></strong></p> | |
| <h2>Flash Attention 1-3</h2> | |
<p>Flash attention is a technique pioneered by <a href="https://tridao.me">Tri Dao</a> that optimizes attention computations by writing custom CUDA kernels that make them much faster <em>and</em> more memory efficient. The idea behind Flash Attention is to make efficient use of the various memories of the GPU to avoid relying too much on the slowest one, the global memory of the GPU (confusingly called High Bandwidth Memory, HBM 🫠).</p>
<p>A basic implementation of the attention mechanism involves a lot of transfers between memory and workers. It requires materializing the S and P matrices (the attention scores $S = QK^T$ and the probabilities $P = \text{softmax}(S)$) in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2070.png" /></p> | |
| <p>Since bandwidth is much lower in HBM this introduces a severe bottleneck in the attention computation. Can we do better? Tri Dao says yes!</p> | |
<p>The key element is to compute the S matrix in small pieces which can fit in the smaller shared memory of the SM. But we can do even better and avoid materializing the very large S matrix altogether, keeping only the statistics needed to compute the normalization factor of the softmax. This way we can compute part of $O$ directly in a single computation in SRAM rather than moving intermediate results back and forth. Not only do we make use of the shared memory, we also remove the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length): the attention matrix.</p>
| <p><img alt="From the FLASH-ATTENTION paper (https://arxiv.org/pdf/2205.14135)" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2071.png" /></p> | |
| <p>From the FLASH-ATTENTION paper (<a href="https://arxiv.org/pdf/2205.14135">https://arxiv.org/pdf/2205.14135</a>)</p> | |
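<p>The trick that lets Flash Attention skip materializing S is the online softmax: scores are processed block by block while keeping only a running maximum <code>m</code>, a running normalizer <code>l</code> and a running unnormalized output, all rescaled whenever a larger maximum shows up. Here is a minimal plain-Python sketch for a single query row with scalar values (so each "value" is a number rather than a vector); the block size and names are illustrative, and this shows the numerical idea only, not the fused GPU kernel.</p>

```python
import math


def softmax_weighted_sum(scores, values):
    """Reference: materializes the full probability row P = softmax(S)."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total


def online_softmax_weighted_sum(scores, values, block_size=2):
    """Processes scores block by block, keeping only running statistics."""
    m = -math.inf  # running maximum of the scores seen so far
    l = 0.0        # running sum of exp(score - m)
    acc = 0.0      # running sum of exp(score - m) * value
    for start in range(0, len(scores), block_size):
        block_s = scores[start:start + block_size]
        block_v = values[start:start + block_size]
        m_new = max(m, max(block_s))
        # Rescale the previous statistics to the new maximum (exp(-inf) == 0.0 on the first block)
        correction = math.exp(m - m_new)
        l = l * correction + sum(math.exp(s - m_new) for s in block_s)
        acc = acc * correction + sum(math.exp(s - m_new) * v for s, v in zip(block_s, block_v))
        m = m_new
    return acc / l  # normalize once at the end


scores = [0.5, 2.0, -1.0, 3.0, 0.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(softmax_weighted_sum(scores, values))
print(online_softmax_weighted_sum(scores, values))  # same result up to rounding
```
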
| <p>The idea of flash attention resolves so many bottlenecks in model training that it has quickly become the default way to perform attention in all transformers:</p> | |
| <ul> | |
<li>By avoiding materializing the S matrix, we <strong>reduce the memory burden</strong> of attention</li>
<li>We also remove a large part of the <strong>practical impact of attention's quadratic cost in sequence length</strong>, since the large score matrix is never written to or read back from HBM</li>
| </ul> | |
<p>As a result, all variants of linear attention and sub-quadratic approaches to approximating attention, developed shortly after the invention of the transformer architecture, have mostly been put aside in favor of this exact and fast Flash Attention implementation and mechanism.</p>
<p>Following Flash Attention 1, two successive improved versions have been released by the same lab: Flash Attention 2 and 3. Compared to Flash Attention 1, the improvements in Flash Attention 2 and 3 are less about the general attention mechanism than about tailoring its low-level implementation more specifically to the GPU by (1) reducing the number of non-matmul operations as much as possible, (2) carefully partitioning the workload among warps and thread blocks (for Flash Attention 2), and (3) carefully optimizing for FP8 and Tensor Core support on the latest Hopper (H100) architecture (for Flash Attention 3).</p>
| <p>Flash-Attention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.</p> | |
<p>However, another general improvement, widely used by AI practitioners, is to reduce as much as possible the time spent waiting for CPU and GPU synchronization. This is done using a technique called “fused kernels”.</p>
| <h2>Fused Kernels</h2> | |
<p>In several places now we’ve mentioned how GPU and CPU operations can be asynchronous. In particular, the host code on the CPU can schedule workloads on the GPU in a non-blocking way.</p>
<p>Non-blocking execution can be useful for overlapping communication and computation, as we saw in several parts of this blog post, but it can be extended to the more general idea of avoiding, at all costs, going back and forth between host and GPU kernel commands.</p>
| <p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible.</p> | |
| <p>This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p> | |
<p>Fused kernels are especially efficient and simple to write for a succession of point-wise operations that are performed independently on each input token. In this case, there is no point in writing computed values back to global memory, only to move them to SM memory again and spin up a new kernel. It’s much more efficient to keep all values local until the whole succession of computations has been performed.</p>
<p>There are many places in a Transformer model where this can be advantageous, for instance whenever a succession of point-wise operations is performed, e.g. in the computations involved in the layer norms.</p>
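<p>To see what fusion saves, the plain-Python model below applies the same chain of three point-wise operations (a scale, a bias add and a ReLU, chosen only as an example) either as three separate "kernels", each reading its input from and writing its output to a list standing in for global memory, or as one fused "kernel" that keeps intermediate values local. It counts element reads and writes to that simulated global memory; the counts, not the Python timings, are the point.</p>

```python
def unfused(x, scale, bias):
    traffic = 0
    # Kernel 1: y = x * scale (read x, write y)
    y = [v * scale for v in x]
    traffic += 2 * len(x)
    # Kernel 2: z = y + bias (read y, write z)
    z = [v + bias for v in y]
    traffic += 2 * len(x)
    # Kernel 3: out = relu(z) (read z, write out)
    out = [max(v, 0.0) for v in z]
    traffic += 2 * len(x)
    return out, traffic


def fused(x, scale, bias):
    out = []
    for v in x:
        # Intermediates stay local ("in registers"); only the final value is stored
        out.append(max(v * scale + bias, 0.0))
    traffic = 2 * len(x)  # one read and one write per element in total
    return out, traffic


x = [-2.0, -0.5, 0.0, 1.5, 3.0]
out_u, traffic_u = unfused(x, 2.0, 0.5)
out_f, traffic_f = fused(x, 2.0, 0.5)
print(traffic_u, traffic_f)  # 30 10: three times less memory traffic
```
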
| <h2>Mixed Precision Training</h2> | |
| <p>The techniques described so far in this section require specific modeling code changes and writing custom kernels for certain operations in order to speed up training. In this section we take a look at a range of methods that are agnostic to the modeling code and can be used for any model.</p> | |
<p>The default numerical precision of PyTorch tensors is the single-precision floating-point format, also called FP32 or float32, which means that every number stored takes up 32 bits or 4 bytes. The available bits to represent a number are divided into 3 parts:</p>
| <ul> | |
| <li>Sign: the first bit determines if the number is positive or negative</li> | |
| <li>Mantissa: determines the significant figures of a number</li> | |
| <li>Exponent: controls the magnitude of the number</li> | |
| </ul> | |
<p>The principle of floating-point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. $- 5.734 \times 10^{7}$, where we first have the sign, followed by the mantissa and the exponent. This way we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default, there is a range of floating-point formats available in PyTorch:</p>
<table>
<thead>
<tr><th>Format</th><th>Total bits</th><th>Sign</th><th>Mantissa</th><th>Exponent</th></tr>
</thead>
<tbody>
<tr><td>float32</td><td>32</td><td>1</td><td>23</td><td>8</td></tr>
<tr><td>float16</td><td>16</td><td>1</td><td>10</td><td>5</td></tr>
<tr><td>bfloat16</td><td>16</td><td>1</td><td>7</td><td>8</td></tr>
<tr><td>float8 (e4m3)</td><td>8</td><td>1</td><td>3</td><td>4</td></tr>
<tr><td>float8 (e5m2)</td><td>8</td><td>1</td><td>2</td><td>5</td></tr>
</tbody>
</table>
| <blockquote> | |
| <p>Note: You might be wondering where the “b” in bfloat16 comes from. The format was developed at Google Brain and thus the “b” stands for “brain”. | |
| </p> | |
| </blockquote> | |
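<p>The bit layout in the table can be inspected directly. The sketch below uses Python's standard <code>struct</code> module to get the 32 raw bits of a float32 and split them into sign, exponent and mantissa fields (in memory the order is sign, exponent, mantissa). It also derives a bfloat16 by keeping the upper 16 bits, which shows why bfloat16 shares float32's 8-bit exponent; plain truncation is a simplification, since conversions normally round to nearest. The function names are ours.</p>

```python
import struct


def float32_fields(x):
    """Split a float32 into its (sign, exponent, mantissa) bit fields."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    sign = bits >> 31
    exponent = (bits >> 23) & 0xFF  # 8 bits, stored with a bias of 127
    mantissa = bits & 0x7FFFFF      # 23 bits, the leading 1 is implicit
    return sign, exponent, mantissa


def bfloat16_truncate(x):
    """bfloat16 obtained by dropping the low 16 bits of the float32 encoding."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


sign, exponent, mantissa = float32_fields(-5.734e7)
print(sign, exponent - 127)        # 1 25: negative, magnitude in [2**25, 2**26)
print(bfloat16_truncate(3.14159))  # 3.140625: only 7 mantissa bits survive
```
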
<p>Reducing the total number of bits comes at a price (no free lunch here either), but we have some control over how to pay: we can sacrifice bits from either the mantissa or the exponent. For this reason there are also two float8 formats, named after their number of exponent and mantissa bits, so that we can choose the most appropriate one. We can look at the possible range of numbers for each format:</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2072.png" /></p> | |
<p>We can see that float32 spans 80 orders of magnitude, and float16 sacrifices a lot of range while bfloat16 maintains the full range. The two float8 formats reduce the range even further: e5m2 can maintain the float16 range, while e4m3 has an even smaller range.</p>
<p>How come some formats are able to maintain the range and others are not? Let’s investigate the resolution by plotting 10,000 points between 1 and 2. Each point is rounded to the nearest representable number in each format:</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2073.png" /></p> | |
<p>We can see here that bfloat16 maintains the range of float32, unlike float16, but at the cost of sacrificing more precision. In the case of float8 the situation is even more dire, as e4m3 can represent only 7 and e5m2 only 3 numbers strictly between 1 and 2.</p>
<p>A common metric to measure a format’s resolution is epsilon: the distance between 1.00 and the next representable number. We can see that for the float32 format $10^{-4}$ is an upper bound (it’s actually $1.19 \times 10^{-7}$). For float16 it is $\sim 10^{-3}$, and for bfloat16 roughly 10x higher still.</p>
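<p>Both quantities can be derived from the bit counts in the table. For an IEEE-754-style format with $e$ exponent bits and $m$ mantissa bits, epsilon is $2^{-m}$, the largest finite value is $(2 - 2^{-m}) \cdot 2^{2^{e-1}-1}$, and there are $2^m - 1$ representable numbers strictly between 1 and 2. The plain-Python sketch below applies these formulas. One caveat: the common FP8 e4m3 variant does not reserve its top exponent for infinities, so its real maximum is 448 rather than the 240 given by the IEEE-style formula, and the code special-cases it. float16 is cross-checked against the half-precision format <code>'e'</code> of Python's <code>struct</code> module.</p>

```python
import struct

FORMATS = {  # name: (exponent bits, mantissa bits)
    "float32": (8, 23),
    "float16": (5, 10),
    "bfloat16": (8, 7),
    "float8_e4m3": (4, 3),
    "float8_e5m2": (5, 2),
}


def describe(exp_bits, man_bits):
    epsilon = 2.0 ** -man_bits              # gap between 1.0 and the next value
    max_exponent = 2 ** (exp_bits - 1) - 1  # top exponent code is reserved for inf/NaN
    max_value = (2.0 - 2.0 ** -man_bits) * 2.0 ** max_exponent
    strictly_between_1_and_2 = 2 ** man_bits - 1
    return epsilon, max_value, strictly_between_1_and_2


for name, (e, m) in FORMATS.items():
    eps, max_val, count = describe(e, m)
    if name == "float8_e4m3":
        max_val = 448.0  # special case: e4m3 reuses the top exponent for finite values
    print(f"{name:12s} eps={eps:.3g} max={max_val:.4g} strictly between 1 and 2: {count}")


def half(v):
    """Round a Python float to float16 and back using the standard library."""
    return struct.unpack("e", struct.pack("e", v))[0]


print(half(1.0 + 2.0 ** -10) > 1.0, half(1.0 + 2.0 ** -12) == 1.0)  # True True
```
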
<p>The idea of mixed precision training is to use some of these lower-precision formats while maintaining the performance of full precision training. It turns out we <strong>can’t</strong> totally abandon float32 and usually will need to maintain some parts in full precision.</p>
| <p>This is why lower precision training is usually called <strong><em>mixed</em></strong> precision training. </p> | |
| <p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further all the way down to 8 bits.</p> | |
| <h3>FP16 and BF16 training</h3> | |
<p>Naively switching all the tensors and operations to float16 unfortunately doesn’t work, and the result is usually diverging losses. However, the original mixed precision training paper (see <a href="https://arxiv.org/pdf/1710.03740">https://arxiv.org/pdf/1710.03740</a>) came up with three tricks to match float32 training:</p>
| <ol> | |
<li><strong>FP32 copy of weights:</strong> There are two possible issues with float16 weights. During training some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to zero, if the updates are very small the difference in magnitude can cause the update to be rounded away entirely during the addition, so the weight stops changing. The fix is to keep a float32 master copy of the weights: the optimizer applies the updates to this copy, and a float16 copy is used for the forward and backward passes.</li>
<li><strong>Loss scaling:</strong> We have a similar issue with the gradients, as gradients tend to be much smaller than 1 and are thus at risk of underflowing. A simple yet effective strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass, and the scaling does not affect training since we unscale before processing the gradients further (e.g. clipping) and before the optimization step.</li>
<li><strong>Accumulation:</strong> Finally, when performing arithmetic operations in float16 such as dot products, we can also face underflows or overflows. The solution is to accumulate the intermediate results of these operations in float32 and cast only the accumulated result back to float16. For the same reason, gradients are also accumulated in float32.</li>
| </ol> | |
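<p>All three failure modes can be reproduced without a GPU. The sketch below is a toy emulation, not how a framework implements mixed precision: it uses Python's <code>struct</code> module (format codes <code>e</code> and <code>f</code>) to round values to fp16 and fp32, and the weight, update, gradient and scale values are arbitrary illustrative choices:</p>

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE-754 half-precision (fp16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE-754 single-precision (fp32) value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

# 1) Master weights: apply 100 updates of 1e-4 to a weight of 1.0
w_fp16, w_master = to_fp16(1.0), to_fp32(1.0)
for _ in range(100):
    w_fp16 = to_fp16(w_fp16 + to_fp16(1e-4))      # lost every time: fp16 spacing above 1.0 is 2**-10
    w_master = to_fp32(w_master + to_fp32(1e-4))  # kept by the fp32 master copy
print(w_fp16, w_master)                           # the fp16 weight never moved

# 2) Loss scaling: a tiny gradient underflows to zero in fp16 unless it is scaled up first
grad, scale = 1e-8, 2.0 ** 16
print(to_fp16(grad))                              # 0.0
unscaled = to_fp32(to_fp16(grad * scale)) / scale  # scale in fp16, unscale in fp32
print(unscaled)                                   # close to 1e-8 again

# 3) Accumulation: sum 10,000 fp16 values of 1e-4 (true sum is about 1.0)
acc_fp16, acc_fp32 = 0.0, 0.0
for _ in range(10_000):
    acc_fp16 = to_fp16(acc_fp16 + to_fp16(1e-4))  # stalls at 0.25, where the spacing becomes 2**-12
    acc_fp32 = to_fp32(acc_fp32 + to_fp16(1e-4))  # same fp16 inputs, fp32 accumulator
print(acc_fp16, acc_fp32)
```

<p>In each case only the storage or accumulation precision changes; the inputs are identical, which is why keeping a few well-chosen quantities in float32 is enough to recover stable training.</p>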
| <p>With these techniques, you get consistently stable training while benefitting from higher throughput due to the faster, lower precision operations. Naturally, as the curious reader you are and by now slightly addicted to maximizing the throughput, you ask the question: can we go further and faster? </p> | |
| <p>Maybe!</p> | |
| <h3>FP8 pretraining</h3> | |
| <p>Even if we perfectly overlap communication with computation, we always eventually run into the low level theoretical FLOPS limit of the hardware itself, i.e. the efficiency of each individual operation on our hardware. This is where numerical precision becomes crucial. For instance, on NVIDIA's H100 GPU, FP8 matrix multiplications (GEMM operations) achieve twice the theoretical FLOPS of bfloat16, making lower-precision training an attractive path for further optimization.</p> | |
| <p>Recent research - including <a href="https://arxiv.org/abs/2310.18313">FP8-LM</a>, <a href="https://github.com/pytorch/ao/tree/main/torchao/float8#torchaofloat8">torchao</a>, and <a href="https://arxiv.org/abs/2412.19437">DeepSeek-V3</a> - has demonstrated the potential of FP8 training for large-scale models. Still, FP8 pretraining introduces a significant challenge: stability. At lower precision, numerical instability often leads to loss divergence, making it difficult to match the accuracy of higher-precision training.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2074.png" /></p> | |
| <p>As <a href="https://arxiv.org/abs/2309.14322">[Wortsman et al.]</a> observed, instability increases as learning rates rise for a fixed model size, making FP8 pretraining particularly tricky.</p> | |
| <p>The first successful, very large-scale training run with FP8 mixed precision was publicly reported in DeepSeek-V3. The authors carefully analyzed each operation of the forward pass (Fprop) as well as the activation (Dgrad) and weight (Wgrad) backward passes. Similar to BF16 mixed precision training, some aggregations and the master weights are kept in higher precision while the operations themselves are performed in FP8. </p> | |
| <p><img alt="Screenshot 2025-02-09 at 22.20.28.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/Screenshot_2025-02-09_at_22.20.28.png" /></p> | |
| <p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision with a smaller range (e.g. FP16 or FP8), we need to normalize the range of values, typically by scaling with their absolute maximum. DeepSeek-V3 also introduces a quantization scheme in which the ranges are normalized per tile, each tile getting its own scaling factor: 1x128 tiles for inputs/activations and 128x128 blocks for weights. This makes the normalization less susceptible to outliers, since a single large value only affects the scale of its own tile. There are a number of additional tricks they deploy to also reduce the memory and communication footprint, which you can follow in <a href="https://arxiv.org/pdf/2412.19437">section 3.3 of the DeepSeek-V3 technical report</a>. </p> | |
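<p>To see why per-tile scaling helps with outliers, here is a toy emulation in plain Python. It is not DeepSeek-V3's implementation: it assumes the E4M3 flavour of FP8 (3 mantissa bits, largest value 448, out-of-range values saturated rather than turned into NaN), emulates the rounding in software, and uses a made-up activation row with one outlier. With a single scale for the whole row, many small values land in FP8's subnormal range or below it and are lost or badly rounded, while 1x128 tiles keep them:</p>

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def round_to_e4m3(x: float) -> float:
    """Emulate rounding to FP8 E4M3 (3 mantissa bits, min normal 2**-6), saturating at +-448."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    exp = max(math.frexp(mag)[1] - 1, -6)  # floor(log2(mag)); subnormals share exponent -6
    step = 2.0 ** (exp - 3)                # gap between neighbouring values with 3 mantissa bits
    return math.copysign(min(round(mag / step) * step, E4M3_MAX), x)

def fake_quantize(values, tile_size):
    """Absmax-scale each tile into the E4M3 range, round, then scale back."""
    out = []
    for start in range(0, len(values), tile_size):
        tile = values[start:start + tile_size]
        amax = max(abs(v) for v in tile) or 1.0  # avoid dividing by zero on an all-zero tile
        scale = E4M3_MAX / amax                  # one scaling factor per tile
        out.extend(round_to_e4m3(v * scale) / scale for v in tile)
    return out

# One row of 256 activations: a single outlier followed by small values
values = [1000.0] + [0.001 * (i % 7 + 1) for i in range(255)]
per_tensor = fake_quantize(values, tile_size=len(values))  # one scale for the whole row
per_tile = fake_quantize(values, tile_size=128)            # 1x128 tiles

print(sum(q == 0.0 for q in per_tensor[128:]), "values of the second tile flushed to zero (per-tensor)")
print(sum(q == 0.0 for q in per_tile[128:]), "values of the second tile flushed to zero (per-tile)")
```

<p>The outlier still distorts its own tile, but the damage no longer spreads to the rest of the row; smaller tiles trade this robustness for more scaling factors to store.</p>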
| <p>Here’s a summary of a few known approaches to FP8 training:</p> | |
| <table> | |
| <thead> | |
| <tr><th>Approach</th><th>GEMM</th><th>Master weight</th><th>Weight</th><th>Gradient</th><th>Optimizer states</th><th>Total memory per parameter</th></tr> | |
| </thead> | |
| <tbody> | |
| <tr><td>bf16 with fp32 mixed precision (baseline)</td><td>bf16</td><td>fp32</td><td>bf16</td><td>bf16</td><td>fp32 + fp32</td><td>4 + 2 + 2 + 4 + 4 = 16 bytes</td></tr> | |
| <tr><td>Transformer Engine</td><td>fp8</td><td>n/a</td><td>fp32</td><td>fp32</td><td>fp32 + fp32</td><td>4 + 4 + 4 + 4 = 16 bytes</td></tr> | |
| <tr><td>FP8-LM’s O3</td><td>fp8</td><td>fp16</td><td>fp8</td><td>fp8</td><td>fp8 + fp16</td><td>2 + 1 + 1 + 1 + 2 = 7 bytes (~56% less)</td></tr> | |
| <tr><td>DeepSeek-V3</td><td>fp8</td><td>fp32</td><td>fp8</td><td>bf16/fp32</td><td>bf16</td><td>?</td></tr> | |
| <tr><td>nanotron’s FP8</td><td>fp8</td><td>bf16</td><td>fp8</td><td>fp8</td><td>fp8 + fp8</td><td>2 + 1 + 1 + 1 + 1 = 6 bytes (~62% less)</td></tr> | |
| </tbody> | |
| </table> | |
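<p>The “Total memory” column is plain per-parameter arithmetic. As a minimal sketch (the component breakdown follows the table above, counting n/a as 0 bytes and leaving out DeepSeek-V3, whose total is not given; the 8B-parameter model size is an arbitrary assumption, and activations are ignored), we can recompute the totals and the savings relative to the 16-byte baseline:</p>

```python
# Bytes per parameter for each component (fp32 = 4, bf16/fp16 = 2, fp8 = 1).
SCHEMES = {
    "bf16 + fp32 baseline": {"master": 4, "weight": 2, "grad": 2, "optimizer": 4 + 4},
    "Transformer Engine": {"master": 0, "weight": 4, "grad": 4, "optimizer": 4 + 4},
    "FP8-LM O3": {"master": 2, "weight": 1, "grad": 1, "optimizer": 1 + 2},
    "nanotron FP8": {"master": 2, "weight": 1, "grad": 1, "optimizer": 1 + 1},
}

def summarize(n_params: float) -> dict:
    """Return {scheme: (bytes per param, total GB, fraction saved vs. the baseline)}."""
    baseline = sum(SCHEMES["bf16 + fp32 baseline"].values())
    report = {}
    for name, parts in SCHEMES.items():
        per_param = sum(parts.values())
        report[name] = (per_param, n_params * per_param / 1e9, 1 - per_param / baseline)
    return report

for name, (per_param, gb, saved) in summarize(8e9).items():  # hypothetical 8B-parameter model
    print(f"{name:22s} {per_param:2d} B/param  {gb:6.1f} GB  saved: {saved:.1%}")
```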
| <p>Overall, FP8 remains an experimental technique and methods are still evolving, but it will likely soon become the standard, replacing bf16 mixed precision. For a public implementation, see nanotron’s implementation in [TODO: link to appendix]. </p> | |
| <p>Looking ahead, Blackwell, the next generation of NVIDIA chips, has <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">been announced</a> to support FP4 training, which should speed up training even further but will no doubt also introduce new training stability challenges.</p> | |
| <p>We have now arrived at the end of the distributed training journey. Let’s take a step back and conclude.</p> | |
| <h1>Conclusion</h1> | |
| <p>We have come a long way, starting from training a small model on one GPU to learning how the largest models like Llama-405B or DeepSeek-V3 are trained efficiently. In distributed training, many concepts sound easy enough when you first hear them, like “Pipeline parallelism just distributes layers on different GPUs”, but we also worked through all the challenging details when implementing those methods. </p> | |
| <p>However, you were not the only one learning along the way: we want to share a few insights we gained, as well as some ideas on what you can work on next if you want to gain more experience in distributed training.</p> | |
| <p>But let’s start with a brief recap of all the things we covered in these past hours and days!</p> | |
| <h2>What you learned</h2> | |
| <p>Working through this whole blog post, you mastered a range of concepts:</p> | |
| <ul> | |
| <li>Basic principles of model training</li> | |
| <li>Collective communication primitives</li> | |
| <li>Memory anatomy of an LLM</li> | |
| <li>Distributed training with DP and ZeRO</li> | |
| <li>Model parallelism with TP, SP, CP and PP</li> | |
| <li>Fast kernels and mixed precision training</li> | |
| <li>Overlapping communication and computation</li> | |
| <li>Profiling distributed training</li> | |
| </ul> | |
| <p>Furthermore, you saw code implementations of most methods and how to benchmark distributed training. But this hasn’t only been a learning experience for you: we also learned a thing or two!</p> | |
| <h2>What we learned</h2> | |
| <p>Benchmarks, profiling, bugs</p> | |
| <h2>What’s next?</h2> | |
| <p>You should now have a good overview of all the distributed training concepts, but there are still things to learn and details we couldn’t cover. To go deeper into the field, we recommend the following steps:</p> | |
| <ul> | |
| <li>Carefully read some of the landmark or very recent papers. You can find a list of some of the most impactful papers in [TODO References]</li> | |
| <li>Start from scratch and implement an algorithm yourself. Often a method only fully “clicks” once you have implemented it yourself.</li> | |
| <li>Dive into one of the widely used frameworks and start contributing: fix bugs, answer issues, or implement a new feature. That’s the best way to get into any ML field!</li> | |
| </ul> | |
| <p>We hope this blog helps you get started in distributed training or helps you to better understand methods that you may already be applying by using some distributed training frameworks.</p> | |
| <h1>References</h1> | |
| <h2>Landmark LLM Scaling Papers</h2> | |
| <p>Megatron-Turing NLG 530B: <a href="https://developer.nvidia.com/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/">https://developer.nvidia.com/blog/using-deepspeed-and-megatron-to-train-megatron-turing-nlg-530b-the-worlds-largest-and-most-powerful-generative-language-model/</a></p> | |
| <p>PaLM: <a href="https://arxiv.org/abs/2204.02311">https://arxiv.org/abs/2204.02311</a></p> | |
| <p>Gemini: <a href="https://arxiv.org/abs/2312.11805">https://arxiv.org/abs/2312.11805</a></p> | |
| <h2>Training Frameworks</h2> | |
| <ul> | |
| <li>FairScale: <a href="https://github.com/facebookresearch/fairscale/tree/main">https://github.com/facebookresearch/fairscale/tree/main</a></li> | |
| <li>Megatron-LM: <a href="https://github.com/NVIDIA/Megatron-LM">https://github.com/NVIDIA/Megatron-LM</a></li> | |
| <li>DeepSpeed: <a href="https://www.deepspeed.ai/">https://www.deepspeed.ai/</a></li> | |
| <li>ColossalAI: <a href="https://colossalai.org/">https://colossalai.org/</a></li> | |
| <li>torchtitan: <a href="https://github.com/pytorch/torchtitan">https://github.com/pytorch/torchtitan</a></li> | |
| <li>GPT-NeoX: <a href="https://github.com/EleutherAI/gpt-neox?tab=readme-ov-file#news">https://github.com/EleutherAI/gpt-neox</a></li> | |
| <li>LitGPT: <a href="https://github.com/Lightning-AI/litgpt">https://github.com/Lightning-AI/litgpt</a></li> | |
| <li>DiLoCo: <a href="https://github.com/PrimeIntellect-ai/OpenDiLoCo">https://github.com/PrimeIntellect-ai/OpenDiLoCo</a></li> | |
| </ul> | |
| <h2>Debugging</h2> | |
| <ul> | |
| <li>Speed profiling: <a href="https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html">https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html</a></li> | |
| <li>Memory profiling: <a href="https://pytorch.org/blog/understanding-gpu-memory-1/">https://pytorch.org/blog/understanding-gpu-memory-1/</a></li> | |
| <li>TensorBoard profiler tutorial: <a href="https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html">https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html</a></li> | |
| </ul> | |
| <h2>Distribution Techniques</h2> | |
| <ul> | |
| <li>Data parallelism: <a href="https://siboehm.com/articles/22/data-parallel-training">https://siboehm.com/articles/22/data-parallel-training</a></li> | |
| <li>ZeRO: <a href="https://arxiv.org/abs/1910.02054">https://arxiv.org/abs/1910.02054</a></li> | |
| <li>FSDP: <a href="https://arxiv.org/pdf/2304.11277">https://arxiv.org/pdf/2304.11277</a> <a href="https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/">https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/</a></li> | |
| <li>Tensor and Sequence Parallelism + Selective Recomputation: <a href="https://arxiv.org/abs/2205.05198">https://arxiv.org/abs/2205.05198</a></li> | |
| <li>Pipeline parallelism: <a href="https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/#pipeline_parallelism">https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/#pipeline_parallelism</a></li> | |
| <li>Breadth first Pipeline Parallelism: <a href="https://arxiv.org/abs/2211.05953">https://arxiv.org/abs/2211.05953</a></li> | |
| <li>All-reduce: <a href="https://andrew.gibiansky.com/blog/machine-learning/baidu-allreduce/">https://andrew.gibiansky.com/blog/machine-learning/baidu-allreduce/</a></li> | |
| <li>Ring-flash-attention: <a href="https://github.com/zhuzilin/ring-flash-attention">https://github.com/zhuzilin/ring-flash-attention</a>, <a href="https://zhuanlan.zhihu.com/p/683714620">https://zhuanlan.zhihu.com/p/683714620</a></li> | |
| <li>Ring attention tutorial: <a href="https://coconut-mode.com/posts/ring-attention/">https://coconut-mode.com/posts/ring-attention/</a></li> | |
| <li>ZeRO and 3D: <a href="https://www.deepspeed.ai/tutorials/large-models-w-deepspeed/#understanding-performance-tradeoff-between-zero-and-3d-parallelism">https://www.deepspeed.ai/tutorials/large-models-w-deepspeed/#understanding-performance-tradeoff-between-zero-and-3d-parallelism</a></li> | |
| <li>Mixed precision training: <a href="https://arxiv.org/pdf/1710.03740">https://arxiv.org/pdf/1710.03740</a></li> | |
| </ul> | |
| <h2>CUDA Kernels</h2> | |
| <h2>Hardware</h2> | |
| <p>Fire-Flyer - a 10,000 PCIe GPU cluster: <a href="https://www.arxiv.org/abs/2408.14158">https://www.arxiv.org/abs/2408.14158</a></p> | |
| <p>Meta’s 24k H100 Pods: <a href="https://engineering.fb.com/2024/03/12/data-center-engineering/building-metas-genai-infrastructure/">https://engineering.fb.com/2024/03/12/data-center-engineering/building-metas-genai-infrastructure/</a></p> | |
| <p>Semianalysis - 100k H100 cluster: <a href="https://www.semianalysis.com/p/100000-h100-clusters-power-network">https://www.semianalysis.com/p/100000-h100-clusters-power-network</a></p> | |
| <h2>Others</h2> | |
| <ul> | |
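| <li>Stas Bekman’s Handbook: <a href="https://github.com/stas00/ml-engineering">https://github.com/stas00/ml-engineering</a></li> | |
| <li>Bloom training chronicles: <a href="https://github.com/bigscience-workshop/bigscience/blob/master/train/tr11-176B-ml/chronicles.md">https://github.com/bigscience-workshop/bigscience/blob/master/train/tr11-176B-ml/chronicles.md</a></li> | |
| <li>OPT logbook: <a href="https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf">https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf</a></li> | |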
| <li>Scaling context: <a href="https://www.harmdevries.com/post/context-length/">https://www.harmdevries.com/post/context-length/</a></li> | |
| <li>Harm’s law for smol models: <a href="https://www.harmdevries.com/post/model-size-vs-compute-overhead/">https://www.harmdevries.com/post/model-size-vs-compute-overhead/</a></li> | |
| </ul> | |
| <h1>Appendix</h1> | |
| <h2>A0: Parallel Programming Crash Course</h2> | |
| <p><strong>Resources:</strong></p> | |
| <ul> | |
| <li><a href="https://github.com/NVIDIA/nccl/issues/256">https://github.com/NVIDIA/nccl/issues/256</a></li> | |
| <li><a href="https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html">https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html</a></li> | |
| <li><a href="https://en.wikipedia.org/wiki/Collective_operation">https://en.wikipedia.org/wiki/Collective_operation</a></li> | |
| <li><a href="https://developer.nvidia.com/blog/massively-scale-deep-learning-training-nccl-2-4/">https://developer.nvidia.com/blog/massively-scale-deep-learning-training-nccl-2-4/</a></li> | |
| <li><a href="https://www.mdpi.com/2076-3417/14/12/5100#:~:text=In%20a%20parameter%20server%20setup,architecture%20for%20ring%20all%2Dreduce">https://www.mdpi.com/2076-3417/14/12/5100</a></li> | |
| </ul> | |
| <p>Throughout this blog post we’ll scale LLM training from one to hundreds of GPUs. This will require the communication and synchronization of weights, gradients, and data between all the machines. There’s a set of distributed patterns to achieve exactly that called <strong>collective operations</strong>. In this section we’ll do a small crash course on <em>Broadcast, AllReduce, Scatter</em> and co. If you are already familiar with these patterns, feel free to move directly to [SECTION I]; otherwise let’s get ☕ #1 (or your neural stimulant of choice) and dig in! </p> | |
| <p>The general setup is that we have a number of independent nodes which could be CPU cores, GPUs, or compute nodes. Each performs some computation and then we want to communicate the result or parts of it to the other nodes for the next computation step (t+1). </p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2075.png" /></p> | |
| <p>Maybe we need to send the result from one node to all other nodes, or we need to sum all the intermediate results from each node to report the overall result. Usually, there is one node with an elevated status that plays a central role, here denoted with <code>root</code> that is the target or source of some operations. Let’s start with one of the simplest primitives: a broadcast operation.</p> | |
| <h3>Broadcast</h3> | |
| <p>A very common pattern is that you have some data on Node 1 and you want to share it with all the other nodes so they can do some computation with the data. The broadcast operation does just that:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2076.png" /></p> | |
| <p>Collective operations are natively provided by PyTorch, so we can easily write a small example that demonstrates how broadcasting works. We first need to initialize a process group with <code>dist.init_process_group</code>, which sets up the communication backend (we’ll talk about NCCL later), determines how many workers (aka nodes) exist and assigns a rank to each one (which we can get with <code>dist.get_rank</code>). Finally, it establishes a connection between the workers.</p> | |
| <p>To showcase the <code>broadcast</code> operation, let's create a tensor with non-zero values on <code>rank=0</code> and tensors full of zeros on the other workers. We then distribute the <code>rank=0</code> tensor to all other ranks with <code>dist.broadcast(tensor, src=0)</code> :</p> | |
| <p>```python | |
| import torch | |
| import torch.distributed as dist | |
| | |
| def init_process(): | |
|     # Reads rank and world size from the environment variables set by torchrun | |
|     dist.init_process_group(backend='nccl') | |
|     # One GPU per process (assumes a single node, so the rank is also the local GPU index) | |
|     torch.cuda.set_device(dist.get_rank()) | |
| | |
| def example_broadcast(): | |
|     if dist.get_rank() == 0: | |
|         tensor = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32).cuda() | |
|     else: | |
|         tensor = torch.zeros(5, dtype=torch.float32).cuda() | |
|     print(f"Before broadcast on rank {dist.get_rank()}: {tensor}") | |
|     # Copy rank 0's tensor into the tensor of every other rank (in place) | |
|     dist.broadcast(tensor, src=0) | |
|     print(f"After broadcast on rank {dist.get_rank()}: {tensor}") | |
| | |
| init_process() | |
| example_broadcast() | |
| ```</p> | |
| <p>You can run the above script with <code>torchrun --nproc_per_node=3 dist_op.py</code> (you’ll need 3 GPUs for this or change <code>nproc_per_node</code> accordingly) and you should see the following output:</p> | |
| <p>```python | |
| Before broadcast on rank 0: tensor([1., 2., 3., 4., 5.], device='cuda:0') | |
| Before broadcast on rank 1: tensor([0., 0., 0., 0., 0.], device='cuda:1') | |
| Before broadcast on rank 2: tensor([0., 0., 0., 0., 0.], device='cuda:2') | |
| | |
| After broadcast on rank 0: tensor([1., 2., 3., 4., 5.], device='cuda:0') | |
| After broadcast on rank 1: tensor([1., 2., 3., 4., 5.], device='cuda:1') | |
| After broadcast on rank 2: tensor([1., 2., 3., 4., 5.], device='cuda:2') | |
| ```</p> | |
| <p>Great, seems like it works as expected. Note that the rank messages can be printed out of order as we have no control over which print statement is executed first (we ordered them here for readability). Now let’s move on to the Reduce and AllReduce patterns! </p> | |
| <h3>Reduce & AllReduce</h3> | |
| <p>Reduce patterns are among the most fundamental patterns in distributed data processing. The idea is that you want to combine the data present on each node through a function <code>f()</code> which can be for instance summation or averaging. In the Reduce paradigm the result is sent to the root node only, whereas in the AllReduce case the result is broadcasted to all nodes:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2077.png" /></p> | |
| <p>Of course, there is no magic “free-flying” node that can perform this operation: generally, each node does a partial computation, with the nodes organized in a ring or tree structure. Here is a simple example: let’s say we need to compute the sum of the numbers held by each node, and our nodes are connected in a ring pattern. The first node sends its number to a neighbour, which adds its own number to the received one before forwarding the result to the next neighbour. At the end of a round along the ring of nodes, the first node will receive the total sum.</p> | |
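<p>The ring round described above fits in a tiny single-process simulation (plain Python, no real communication; the nodes are just list entries and the function name is hypothetical). Note that the N messages are sent one after the other and each carries the full payload; for large arrays this leaves most links idle, which motivates the chunked Ring All-Reduce discussed further down:</p>

```python
def ring_reduce_sum(values):
    """Sum one number per node by passing a running total around a ring.

    Node 0 sends its number to node 1; every following node adds its own number
    and forwards the running total; the last node sends the total back to node 0.
    Returns (total received by node 0, number of messages sent).
    """
    running = values[0]       # node 0 starts the round with its own number
    messages = 1              # node 0 -> node 1
    for value in values[1:]:  # node i adds its number ...
        running += value
        messages += 1         # ... and forwards the running total to node i + 1 (mod N)
    return running, messages

print(ring_reduce_sum([1, 2, 3, 4]))  # (10, 4)
```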
| <p>Here’s the code to run a simple Reduce operation summing the tensors, we specify the operation to use with <code>op=dist.ReduceOp.SUM</code> (you can find more information on the supported operations in the docs: <a href="https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp">https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp</a>):</p> | |
| <p>```python | |
| def example_reduce(): | |
|     # Rank r holds a tensor filled with r + 1 | |
|     tensor = torch.tensor([dist.get_rank() + 1] * 5, dtype=torch.float32).cuda() | |
|     print(f"Before reduce on rank {dist.get_rank()}: {tensor}") | |
|     # Sum all tensors; only the destination rank (dst=0) receives the result | |
|     dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM) | |
|     print(f"After reduce on rank {dist.get_rank()}: {tensor}") | |
| | |
| init_process() | |
| example_reduce() | |
| ```</p> | |
| <p>Note that in the Reduce operation only the tensor on the <code>dst</code> node is updated:</p> | |
| <p>```python | |
| Before reduce on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0') | |
| Before reduce on rank 1: tensor([2., 2., 2., 2., 2.], device='cuda:1') | |
| Before reduce on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2') | |
| | |
| After reduce on rank 0: tensor([6., 6., 6., 6., 6.], device='cuda:0') | |
| After reduce on rank 1: tensor([2., 2., 2., 2., 2.], device='cuda:1') | |
| After reduce on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2') | |
| ```</p> | |
| <p>Similarly we can perform an AllReduce (we don’t need to specify a destination in this case):</p> | |
| <p>```python | |
| def example_all_reduce(): | |
|     tensor = torch.tensor([dist.get_rank() + 1] * 5, dtype=torch.float32).cuda() | |
|     print(f"Before all_reduce on rank {dist.get_rank()}: {tensor}") | |
|     # Sum all tensors in place; every rank ends up with the result | |
|     dist.all_reduce(tensor, op=dist.ReduceOp.SUM) | |
|     print(f"After all_reduce on rank {dist.get_rank()}: {tensor}") | |
| | |
| init_process() | |
| example_all_reduce() | |
| ```</p> | |
| <p>In this case the result is available on all nodes:</p> | |
| <p>```python | |
| Before all_reduce on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0') | |
| Before all_reduce on rank 1: tensor([2., 2., 2., 2., 2.], device='cuda:1') | |
| Before all_reduce on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2') | |
| | |
| After all_reduce on rank 0: tensor([6., 6., 6., 6., 6.], device='cuda:0') | |
| After all_reduce on rank 1: tensor([6., 6., 6., 6., 6.], device='cuda:1') | |
| After all_reduce on rank 2: tensor([6., 6., 6., 6., 6.], device='cuda:2') | |
| ```</p> | |
| <h3><strong>A quick focus on Ring All-Reduce</strong></h3> | |
| <p><strong>Ring All-Reduce</strong> is one specific implementation of All-Reduce, optimized for scalability. Rather than all devices communicating with each other directly, which could create communication bottlenecks, Ring All-Reduce can be broken down into two key steps: <strong>Reduce-Scatter</strong> and <strong>All-Gather</strong>. Here's how it works.</p> | |
| <p><strong>1. Reduce-Scatter:</strong></p> | |
| <ul> | |
| <li>Each device splits its data (e.g., gradients) into chunks and sends one chunk to its neighbor. Simultaneously, each device receives a chunk from its other neighbor.</li> | |
| <li>As each device receives a chunk, it adds (reduces) its corresponding chunk to the received one.</li> | |
| <li>This process continues around the ring until each device holds one fully reduced chunk, i.e. the sum of the gradients for that chunk across all devices.</li> | |
| </ul> | |
| <p><strong>2. All-Gather:</strong></p> | |
| <ul> | |
| <li>Now, each device needs to collect the fully reduced chunks from other devices.</li> | |
| <li>The devices start sending their reduced chunks to neighbors.</li> | |
| <li>Each device forwards the chunks it receives until every device has all the fully reduced chunks, giving each device the complete, summed-up gradient.</li> | |
| </ul> | |
| <p>Let’s illustrate this with the following GIFs, where we have 5 GPUs, each with a tensor of length 5. The first GIF represents the reduce-scatter step, at the end of which each GPU holds the reduced result for a specific chunk of data (orange rectangle). The second GIF represents the all-gather step, at the end of which each GPU holds the full result of the all-reduce operation.</p> | |
| <p><img alt=" Reduce-scatter" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/reduce_scatter.gif" /></p> | |
| <p><em>Reduce-scatter</em></p> | |
| <p><img alt=" All-Gather" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/all_gather.gif" /></p> | |
| <p><em>All-Gather</em></p> | |
| <p>You may have noticed that each of the $N$ GPUs sends and receives values $N-1$ times during both the reduce-scatter and the all-gather steps. Each GPU sends $\frac{K}{N}$ values per transfer, where $K$ is the total number of values in the array being summed across the GPUs. Therefore, the total amount of data each GPU sends (and, symmetrically, receives) is $2 \times (N-1) \times \frac{K}{N}$. When $N$ (the number of GPUs) is large, this is approximately $2 \times K$, where $K$ is here the total number of parameters, since we are summing gradients. </p> | |
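<p>To make the two phases concrete, here is a minimal single-process simulation in plain Python (no real GPUs or NCCL: each “GPU” is a list and each send is a copy). The chunk-indexing convention, i.e. which chunk GPU $i$ sends at step $s$, is one common choice and an assumption of this sketch, as is the requirement that $K$ be divisible by $N$. It also counts the values each GPU sends, which should equal $2 \times (N-1) \times \frac{K}{N}$:</p>

```python
def ring_all_reduce(data):
    """Simulate ring all-reduce (reduce-scatter, then all-gather) on N simulated GPUs.

    data[i] is the list of K numbers held by GPU i; K must be divisible by N.
    Returns (final arrays per GPU, number of values each GPU sent).
    """
    n, k = len(data), len(data[0])
    assert k % n == 0, "this sketch assumes K is divisible by N"
    size = k // n
    buf = [list(row) for row in data]  # each GPU's local copy
    sent = [0] * n

    def chunk(c):
        """Indices of chunk c (taken modulo N)."""
        c %= n
        return range(c * size, (c + 1) * size)

    # Reduce-scatter: at step s, GPU i sends chunk (i - s) to GPU i + 1, which adds it in.
    for s in range(n - 1):
        msgs = [(i, [buf[i][j] for j in chunk(i - s)]) for i in range(n)]  # all sends happen at once
        for i, payload in msgs:
            for j, v in zip(chunk(i - s), payload):
                buf[(i + 1) % n][j] += v
            sent[i] += len(payload)
    # GPU i now holds the fully reduced chunk (i + 1) mod N.

    # All-gather: at step s, GPU i forwards chunk (i + 1 - s) to GPU i + 1, which overwrites its copy.
    for s in range(n - 1):
        msgs = [(i, [buf[i][j] for j in chunk(i + 1 - s)]) for i in range(n)]
        for i, payload in msgs:
            for j, v in zip(chunk(i + 1 - s), payload):
                buf[(i + 1) % n][j] = v
            sent[i] += len(payload)
    return buf, sent

# 5 GPUs with arrays of length 5, as in the GIFs: GPU i holds [i, i, i, i, i]
result, sent = ring_all_reduce([[float(i)] * 5 for i in range(5)])
print(result[0])  # [10.0, 10.0, 10.0, 10.0, 10.0]
print(sent)       # [8, 8, 8, 8, 8], i.e. 2 * (N - 1) * K / N values per GPU
```

<p>Every GPU sends exactly one chunk per step, so all links of the ring are busy at the same time; this is what keeps the per-GPU traffic bounded as $N$ grows.</p>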
| <p><strong>There are two key things to keep in mind for all-reduce:</strong></p> | |
| <ol> | |
| <li>The communication cost for all-reduce is approximately $2 \times K$ when $N$ (the number of GPUs) is large.</li> | |
| <li>An all-reduce operation can be broken down into a reduce-scatter followed by an all-gather. Each of these two operations costs half as much as the all-reduce, approximately $K$.</li> | |
| </ol> | |
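<p>As a worked example with illustrative numbers (assumed, not taken from a specific training run): all-reducing the gradients of a model with $K = 8 \times 10^9$ parameters stored in bf16 (2 bytes each) across $N = 64$ GPUs means each GPU sends</p>
<p>$$2 \times \frac{N-1}{N} \times K \times 2\ \text{bytes} = 2 \times \frac{63}{64} \times 8 \times 10^{9} \times 2\ \text{bytes} = 31.5\ \text{GB},$$</p>
<p>already close to the large-$N$ limit of $2 \times K \times 2\ \text{bytes} = 32\ \text{GB}$. A reduce-scatter or an all-gather alone sends half of that, about 15.75 GB per GPU.</p>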
| <p>As we can see, this implementation makes efficient use of even limited bandwidth between nodes, since the amount of data each GPU sends stays close to $2 \times K$ no matter how many GPUs take part.</p> | |
| <p>Now let’s turn to our next distributed communication operation. In many real cases, each node individually performs many complex computations and we need to share the final results among nodes. Gather and AllGather are the operations we want to use in this case. Let’s take a look! </p> | |
| <h3>Gather & AllGather</h3> | |
| <p>Gather and AllGather are quite similar to Broadcast in that they distribute data among nodes without modifying it. The main difference is that, instead of one value being shared from one node to all other nodes, each node holds an individual chunk of data, and we want to collect all these chunks either on one node (in the case of Gather) or on all nodes (in the case of AllGather). A picture being worth 1000 words, let’s take a look:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2078.png" /></p> | |
| <p>Note that the dashed lines indicate that some data actually doesn’t move at all (since it’s already present on the node).</p> | |
| <p>In the case of the gather operation, we need to prepare a container object where the gathered tensors can be stored, in this example the <code>gather_list</code>:</p> | |
| <p>```python | |
| def example_gather(): | |
|     tensor = torch.tensor([dist.get_rank() + 1] * 5, dtype=torch.float32).cuda() | |
|     if dist.get_rank() == 0: | |
|         # Only the destination rank needs buffers to receive one tensor per rank | |
|         gather_list = [torch.zeros(5, dtype=torch.float32).cuda() for _ in range(dist.get_world_size())] | |
|     else: | |
|         gather_list = None | |
|     print(f"Before gather on rank {dist.get_rank()}: {tensor}") | |
|     dist.gather(tensor, gather_list, dst=0) | |
|     if dist.get_rank() == 0: | |
|         print(f"After gather on rank 0: {gather_list}") | |
| | |
| init_process() | |
| example_gather() | |
| ```</p> | |
| <p>And we see that the <code>gather_list</code> indeed contains the tensors of all ranks:</p> | |
| <p>```python | |
| Before gather on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0') | |
| Before gather on rank 1: tensor([2., 2., 2., 2., 2.], device='cuda:1') | |
| Before gather on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2') | |
| | |
| After gather on rank 0: [tensor([1., 1., 1., 1., 1.], device='cuda:0'), | |
| tensor([2., 2., 2., 2., 2.], device='cuda:0'), | |
| tensor([3., 3., 3., 3., 3.], device='cuda:0')] | |
| ```</p> | |
| <p>The only thing we need to change for the AllGather example is that every node will need a placeholder for the results:</p> | |
| <p>```python | |
| def example_all_gather(): | |
|     tensor = torch.tensor([dist.get_rank() + 1] * 5, dtype=torch.float32).cuda() | |
|     # Every rank needs buffers to receive one tensor from each rank | |
|     gather_list = [torch.zeros(5, dtype=torch.float32).cuda() for _ in range(dist.get_world_size())] | |
|     print(f"Before all_gather on rank {dist.get_rank()}: {tensor}") | |
|     dist.all_gather(gather_list, tensor) | |
|     print(f"After all_gather on rank {dist.get_rank()}: {gather_list}") | |
| | |
| init_process() | |
| example_all_gather() | |
| ```</p> | |
| <p>And indeed we can see that now each node has all the data:</p> | |
| <p>```python | |
| Before all_gather on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0') | |
| Before all_gather on rank 1: tensor([2., 2., 2., 2., 2.], device='cuda:1') | |
| Before all_gather on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2') | |
| | |
| After all_gather on rank 0: [tensor([1., 1., 1., 1., 1.], device='cuda:0'), | |
| tensor([2., 2., 2., 2., 2.], device='cuda:0'), | |
| tensor([3., 3., 3., 3., 3.], device='cuda:0')] | |
| After all_gather on rank 1: [tensor([1., 1., 1., 1., 1.], device='cuda:1'), | |
| tensor([2., 2., 2., 2., 2.], device='cuda:1'), | |
| tensor([3., 3., 3., 3., 3.], device='cuda:1')] | |
| After all_gather on rank 2: [tensor([1., 1., 1., 1., 1.], device='cuda:2'), | |
| tensor([2., 2., 2., 2., 2.], device='cuda:2'), | |
| tensor([3., 3., 3., 3., 3.], device='cuda:2')] | |
| ```</p> | |
| <p>Now what about the inverse of a gather? In this case we would have all the data on one node and want to distribute/slice it among the nodes, possibly with some intermediate processing. We can use the Scatter pattern or, if an operation is applied in between, the ReduceScatter pattern:</p> | |
| <h3>Scatter & ReduceScatter</h3> | |
| <p>As the name subtly suggests, the goal of the Scatter operation is to take data on one node and distribute slices of it to all other nodes. It’s thus different from the Broadcast operation, which copies data without slicing, and it’s the logical inverse of the Gather operation.</p> | |
| <p>The ReduceScatter pattern is slightly more complex: imagine you apply an operation like in the Reduce case but instead of moving the result to just one node we also distribute it evenly to all nodes:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2079.png" /></p> | |
| <p>The Scatter operation is written in code as the opposite of the Gather: instead of preparing a list of tensors as target we prepare the source data as a list of tensors we want to distribute. We also need to specify the <code>src</code>:</p> | |
| <p>```python | |
| def example_scatter(): | |
|     if dist.get_rank() == 0: | |
|         # Only the source rank provides the list of slices, one per rank | |
|         scatter_list = [torch.tensor([i + 1] * 5, dtype=torch.float32).cuda() for i in range(dist.get_world_size())] | |
|         print(f"Rank 0: Tensor to scatter: {scatter_list}") | |
|     else: | |
|         scatter_list = None | |
|     tensor = torch.zeros(5, dtype=torch.float32).cuda() | |
|     print(f"Before scatter on rank {dist.get_rank()}: {tensor}") | |
|     dist.scatter(tensor, scatter_list, src=0) | |
|     print(f"After scatter on rank {dist.get_rank()}: {tensor}") | |
| | |
| init_process() | |
| example_scatter() | |
| ```</p> | |
| <p>As a result we can see how the empty tensors got filled with the contents of the <code>scatter_list</code>:</p> | |
| <p>```python | |
| Rank 0: Tensor to scatter: [tensor([1., 1., 1., 1., 1.], device='cuda:0'), | |
| tensor([2., 2., 2., 2., 2.], device='cuda:0'), | |
| tensor([3., 3., 3., 3., 3.], device='cuda:0')] | |
| Before scatter on rank 0: tensor([0., 0., 0., 0., 0.], device='cuda:0') | |
| Before scatter on rank 1: tensor([0., 0., 0., 0., 0.], device='cuda:1') | |
| Before scatter on rank 2: tensor([0., 0., 0., 0., 0.], device='cuda:2') | |
| | |
| After scatter on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0') | |
| After scatter on rank 1: tensor([2., 2., 2., 2., 2.], device='cuda:1') | |
| After scatter on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2') | |
| ```</p> | |
| <p>Let’s create more interesting data to demonstrate the ReduceScatter logic: on each node we create a list of 2-element vectors whose values scale with the node rank and are raised to a power given by their position in the list (it’s a bit hard to imagine, so just look below for an example): </p> | |
| <p>```python | |
| def example_reduce_scatter(): | |
|     rank = dist.get_rank() | |
|     world_size = dist.get_world_size() | |
|     # One 2-element tensor per destination rank j: [(rank+1)*1, (rank+1)*2] ** (j+1) | |
|     input_tensor = [torch.tensor([(rank + 1) * i for i in range(1, 3)], dtype=torch.float32).cuda()**(j+1) for j in range(world_size)] | |
|     output_tensor = torch.zeros(2, dtype=torch.float32).cuda() | |
|     print(f"Before ReduceScatter on rank {rank}: {input_tensor}") | |
|     # Rank j receives the sum over all ranks of input_tensor[j] | |
|     dist.reduce_scatter(output_tensor, input_tensor, op=dist.ReduceOp.SUM) | |
|     print(f"After ReduceScatter on rank {rank}: {output_tensor}") | |
| | |
| init_process() | |
| example_reduce_scatter() | |
| ```</p> | |
| <p>Let’s print the pattern of data that we created. We also immediately see the ReduceScatter pattern: the first rank received the sum of the first tensor from each node, and the second rank contains the sum of the second tensor on each node and so on:</p> | |
| <p>```python | |
| Before ReduceScatter on rank 0: [tensor([1., 2.], device='cuda:0'), | |
| tensor([1., 4.], device='cuda:0'), | |
| tensor([1., 8.], device='cuda:0')] | |
| Before ReduceScatter on rank 1: [tensor([2., 4.], device='cuda:1'), | |
| tensor([ 4., 16.], device='cuda:1'), | |
| tensor([ 8., 64.], device='cuda:1')] | |
| Before ReduceScatter on rank 2: [tensor([3., 6.], device='cuda:2'), | |
| tensor([ 9., 36.], device='cuda:2'), | |
                                    tensor([ 27., 216.], device='cuda:2')]

After ReduceScatter on rank 0: tensor([ 6., 12.], device='cuda:0')
| After ReduceScatter on rank 1: tensor([14., 56.], device='cuda:1') | |
| After ReduceScatter on rank 2: tensor([ 36., 288.], device='cuda:2') | |
| ```</p> | |
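<p>You can check these numbers without a GPU cluster. Below is a minimal pure-Python sketch of the ReduceScatter semantics. It assumes a SUM reduction and that each rank holds one input chunk per destination rank. The helpers <code>make_inputs</code> and <code>reduce_scatter</code> are hypothetical: they only reproduce the arithmetic that <code>dist.reduce_scatter</code> performs and do no communication.</p>

```python
def make_inputs(rank, world_size):
    # Same data as example_reduce_scatter: base vector [(rank+1)*1, (rank+1)*2],
    # raised elementwise to the power j+1 for the j-th chunk.
    base = [(rank + 1) * i for i in range(1, 3)]
    return [[float(v ** (j + 1)) for v in base] for j in range(world_size)]

def reduce_scatter(inputs_per_rank):
    # inputs_per_rank[r][j] is the j-th chunk held by rank r.
    # Rank j ends up with the elementwise sum over r of inputs_per_rank[r][j].
    world_size = len(inputs_per_rank)
    outputs = []
    for j in range(world_size):
        chunks = [inputs_per_rank[r][j] for r in range(world_size)]
        outputs.append([sum(values) for values in zip(*chunks)])
    return outputs

world_size = 3
inputs = [make_inputs(rank, world_size) for rank in range(world_size)]
for rank, out in enumerate(reduce_scatter(inputs)):
    # Prints [6.0, 12.0], [14.0, 56.0] and [36.0, 288.0], matching the GPU run above
    print(f"After ReduceScatter on rank {rank}: {out}")
```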
<p>We have now seen the main building blocks of distributed operations. Before we see them in action, let’s look at a special operation used for synchronization: the Barrier.</p>
| <h3>Barrier</h3> | |
<p>The Barrier is a simple operation to synchronize all nodes. A barrier is not lifted until all nodes have reached it. Only then are they allowed to continue with further computations:</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2080.png" /></p> | |
<p>We can easily simulate delayed nodes by setting a different sleep time on each node and measuring how long it takes for all of them to pass the barrier:</p>
<p>```python
import time

def example_barrier():
    rank = dist.get_rank()
    t_start = time.time()
    print(f"Rank {rank} sleeps {rank} seconds.")
    time.sleep(rank)  # Simulate different processing times
    dist.barrier()
    print(f"Rank {rank} after barrier time delta: {time.time()-t_start:.4f}")

init_process()
example_barrier()
```</p>
<p>Although the first rank didn’t sleep at all, it also took 2 seconds to pass the barrier, because it had to wait for the slowest rank:</p>
| <p>```python | |
| Rank 0 sleeps 0 seconds. | |
| Rank 1 sleeps 1 seconds. | |
Rank 2 sleeps 2 seconds.

Rank 0 after barrier time delta: 2.0025
| Rank 1 after barrier time delta: 2.0025 | |
| Rank 2 after barrier time delta: 2.0024 | |
| ```</p> | |
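<p>You can reproduce the same behavior on a single machine, with Python threads standing in for ranks. This minimal sketch uses the standard library’s <code>threading.Barrier</code> instead of <code>torch.distributed</code>. The sleep times are scaled down to 0, 0.1 and 0.2 seconds so it runs quickly. The function name <code>run_barrier_demo</code> is our own.</p>

```python
import threading
import time

def run_barrier_demo(sleep_times):
    """Each thread plays one 'rank': it sleeps, then waits at a shared barrier."""
    barrier = threading.Barrier(len(sleep_times))
    t_start = time.perf_counter()
    deltas = [0.0] * len(sleep_times)

    def worker(rank):
        time.sleep(sleep_times[rank])  # simulate different processing times
        barrier.wait()                 # nobody passes until every thread has arrived
        deltas[rank] = time.perf_counter() - t_start

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(len(sleep_times))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return deltas

for rank, delta in enumerate(run_barrier_demo([0.0, 0.1, 0.2])):
    print(f"Rank {rank} after barrier time delta: {delta:.4f}")
```

As in the GPU run, every thread reports roughly the sleep time of the slowest one.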
<p>We need to be careful when synchronizing all nodes like this. It defeats the purpose of parallel, independent operations and can slow down the whole process. In many situations, it is perfectly fine for a fast node to start the next job early: it might be the slower one in the next iteration, which evens out the delays over the whole run.</p>
| <p>Before turning to practical distributed training implementations, let’s first solve a mystery: what the heck is NCCL?</p> | |
| <h3>NCCL: NVIDIA Collective Communications Library</h3> | |
<p>When training large models on many GPUs, we may sometimes strike gold, but we will always encounter nickel (or NCCL)! What is that?</p>
<p>Several libraries implement collective communication and are supported by PyTorch: the classic <code>MPI</code> (Message Passing Interface), <code>Gloo</code> by Meta, and <code>NCCL</code> (NVIDIA Collective Communications Library). They all provide similar collective communication patterns but are optimized for different hardware setups. NCCL is designed to serve GPU-GPU communication efficiently, while MPI and Gloo are geared toward CPU-CPU or CPU-GPU communication. PyTorch provides a <a href="https://pytorch.org/docs/stable/distributed.html#which-backend-to-use">great guide</a> to decide which one to use:</p>
| <ul> | |
| <li>GPU training: use NCCL</li> | |
| <li>CPU training: use Gloo</li> | |
| </ul> | |
| <p>There are a few finer points in the decision tree that we leave to the reader to explore in the PyTorch guide referenced above.</p> | |
<p>Now that we have covered the fundamental operations for distributed training and when to use them, let’s turn to practical implementation.</p>
| <h2>A1: Profiling</h2> | |
| <h3>Kernels</h3> | |
| <p>Let’s begin by assuming for now that the kernels are already integrated into PyTorch. As a simple example, we can look at the Layer Normalization function implemented in PyTorch as <code>torch.nn.functional.layer_norm</code>. There are several methods to profile the kernel that underlies this function. The most straightforward approach might be to use the Python <code>time</code> module. However, since CUDA operations are asynchronous, measuring time with this method will only capture the overhead associated with launching the kernel in Python, rather than the actual execution time of the kernel itself.</p> | |
<p>To address this, we can use <code>torch.cuda.Event</code> for accurate timing and call <code>torch.cuda.synchronize()</code> to make sure we wait for the kernel execution to complete. The following snippet demonstrates this approach:</p>
<p>```python
import torch

def profile_pytorch(func, input):
    # Create CUDA events to track time. CUDA operations are asynchronous, so we
    # time them with events recorded on the GPU stream rather than host timers.
    start = torch.cuda.Event(enable_timing=True)  # Event to mark the start time
    end = torch.cuda.Event(enable_timing=True)  # Event to mark the end time
    # Warmup to eliminate any overhead from the first runs, which might not reflect
    # the actual performance.
    for _ in range(10):
        func(input)
    # Record the start event before executing the function
    start.record()
    func(input)  # Call the function we want to profile
    # Record the end event right after the function's kernels have been queued
    end.record()
    # Synchronize the CUDA operations to ensure all operations are completed
    # before measuring the elapsed time.
    torch.cuda.synchronize()
    # Calculate and return the elapsed time in milliseconds.
    return start.elapsed_time(end)
```</p>
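<p>To see why the synchronization step matters, here is a toy, GPU-free model of asynchronous execution. A single worker thread plays the role of a CUDA stream, so <code>launch</code> returns immediately while the work runs in the background. The <code>ToyStream</code> class and <code>fake_kernel</code> are hypothetical: they only mimic the launch/synchronize pattern and are not a CUDA API.</p>

```python
import time
from concurrent.futures import ThreadPoolExecutor

class ToyStream:
    """Hypothetical stand-in for a CUDA stream: queued work runs in order on one worker."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending = []

    def launch(self, fn, *args):
        # Like a kernel launch: enqueue the work and return without waiting for it.
        self._pending.append(self._pool.submit(fn, *args))

    def synchronize(self):
        # Like torch.cuda.synchronize(): block until all queued work has finished.
        for future in self._pending:
            future.result()
        self._pending.clear()

    def close(self):
        self._pool.shutdown(wait=True)

def fake_kernel(duration_s):
    time.sleep(duration_s)  # stands in for GPU execution time

stream = ToyStream()
t0 = time.perf_counter()
stream.launch(fake_kernel, 0.2)
naive_ms = (time.perf_counter() - t0) * 1e3   # measures only the launch overhead
stream.synchronize()
synced_ms = (time.perf_counter() - t0) * 1e3  # includes the actual "kernel" runtime
stream.close()
print(f"without synchronize: {naive_ms:.2f} ms, with synchronize: {synced_ms:.2f} ms")
```

The first number is tiny and says nothing about the work itself. That is exactly the trap of timing CUDA code with the <code>time</code> module without synchronizing.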
| <p>A more effective approach to profiling is to utilize the PyTorch Profiler, as explained previously. For example, consider the following code:</p> | |
<p>```python
import torch
import torch.nn.functional as F

def pytorch_layer_norm(input):
    return F.layer_norm(input, input.size()[1:])

a = torch.randn(10000, 10000).cuda()

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,  # Profile CPU activities
        torch.profiler.ProfilerActivity.CUDA,  # Profile CUDA activities
    ],
    # Define a schedule for the profiler
    schedule=torch.profiler.schedule(
        wait=1,  # Wait for 1 iteration before starting to profile
        warmup=3,  # Warm up for 3 iterations to stabilize performance
        active=2,  # Profile for 2 active iterations
        repeat=1,  # Run the wait/warmup/active cycle only once
    ),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('.'),
) as p:
    for step in range(10):
        pytorch_layer_norm(a)
        p.step()

# Print a table of the profiling results, sorted by total CUDA time, limited to the top 8 entries
print(p.key_averages().table(sort_by="cuda_time_total", row_limit=8))
```</p>
| <p>This would print aggregated profiling results sorted by the total CUDA time, and the output would be:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2081.png" /></p> | |
<p>You can also inspect the trace in <code>chrome://tracing/</code>, as mentioned previously.</p>
| <blockquote> | |
| <p>If you're new to this tool, you can navigate the trace by using the right and left arrow keys. Additionally, you can zoom in and out by holding the <strong>Alt</strong> key while scrolling left or right with your mouse. | |
| </p> | |
| </blockquote> | |
| <p>After zooming in, you can observe the flow of operations when calling <code>layer_norm</code> in this trace:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2082.png" /></p> | |
| <p>The sequence begins in the CPU (the upper section) with <code>aten::layer_norm</code>, progressing to <code>aten::native_layer_norm</code>, and then transitioning to <code>cudaLaunchKernel</code>. From there, we move on to the GPU, where the <code>vectorized_layer_norm_kernel</code> kernel is called. </p> | |
| <blockquote> | |
| <p>Note that you can enable memory profiling by setting <code>profile_memory</code> to <code>True</code> in the profiler. However, this can lead to more complex traces. | |
| </p> | |
| </blockquote> | |
<p>While the PyTorch Profiler offers a quick performance overview, <strong>NVIDIA Nsight Compute (ncu)</strong> provides deeper insights into GPU performance, including detailed execution times and memory usage for each kernel. Running it is straightforward:</p>
<p><code>ncu --set full python layer_norm.py</code></p>
<p>Here, <code>layer_norm.py</code> is a small script that runs the layer normalization function. This command prints log output to the terminal. A more convenient way to visualize the results is to write them to a file with the <code>-o</code> output flag:</p>
<p><code>ncu --set full -o output python layer_norm.py</code></p>
<p>Opening the resulting <code>output.ncu-rep</code> file in Nsight Compute gives a view that looks like this:</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2083.png" /></p> | |
<p>The report includes clear warnings about compute and memory utilization. It also suggests how to make the kernel better balance compute and memory and achieve maximal occupancy.</p>
| <p><strong>CPP extension</strong></p> | |
| <p>If the kernel you want to profile isn’t already integrated into PyTorch, you can use PyTorch's <code>cpp_extension</code> module to easily compile and run custom CUDA code. The process is straightforward—just create your CUDA kernel in a <code>.cu</code> file, and use the <code>load</code> function from the <code>cpp_extension</code> module to load it in Python.</p> | |
<p>The <code>.cu</code> file would look like this for a simple <code>add</code> kernel:</p>
<p>```cpp
#include &lt;torch/extension.h&gt;
#include &lt;cuda.h&gt;
#include &lt;cuda_runtime.h&gt;

// Each thread adds one pair of elements.
__global__ void add_kernel(float* x, float* y, float* output, int size) {
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index &lt; size) {
        output[index] = x[index] + y[index];
    }
}

// Host-side launcher; expects contiguous float32 CUDA tensors of the same size.
void add_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor output) {
    int threads = 1024;
    int blocks = (x.size(0) + threads - 1) / threads;  // enough blocks to cover every element

    add_kernel&lt;&lt;&lt;blocks, threads&gt;&gt;&gt;(x.data_ptr&lt;float&gt;(), y.data_ptr&lt;float&gt;(), output.data_ptr&lt;float&gt;(), x.size(0));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("add_cuda", &amp;add_cuda, "Vector addition (CUDA)");
}
```</p>
<p>And the Python file to load the kernel:</p>
<p>```python
import torch
from torch.utils.cpp_extension import load

# Load and compile the CUDA extension (expects add_kernel.cu in the working directory)
vector_add = load(
    name="vector_add",
    sources=["add_kernel.cu"],
    verbose=True
)

# Define input tensors
size = 10000
x = torch.randn(size, device='cuda')
y = torch.randn(size, device='cuda')
output = torch.empty(size, device='cuda')

# Run the CUDA kernel
vector_add.add_cuda(x, y, output)

# Check the result against PyTorch's own addition
torch.testing.assert_close(output, x + y)
```</p>
| <p>Using this method, you can profile the custom CUDA kernel just as we demonstrated earlier with PyTorch's profiler or NVIDIA tools.</p> | |
| <h2>A2: TP Backward pass</h2> | |
<p>What happens during the backward pass? Usually PyTorch takes care of it for us. However, relying on <code>torch.autograd</code> for automatic differentiation is not always optimal, and we might need a more precise understanding of, and control over, what happens in the backward pass. Let’s dive into computing the gradients with respect to our two important variables:</p>
| <ul> | |
<li>Gradient w.r.t. the inputs</li>
<li>Gradient w.r.t. the weights</li>
| </ul> | |
| <p>[TODO: link to the reminder of the chain rule computation when explaining forward activation memory usage ?]</p> | |
<p>We can break this down with a simple example. Recall that a linear layer computes (omitting the bias for simplicity):</p>
| <p>$$ | |
| Y = XW | |
| $$</p> | |
<p>where Y is the output, X is the input and W is the weight matrix. To compute the gradient w.r.t. the input, we can use the chain rule:</p>
<p>$$
\frac{dL}{dX} = \frac{dL}{dY} \frac{dY}{dX} = \frac{dL}{dY} W^\top
$$</p>
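<p>The chain rule applies here because the loss (L) depends on the input (X) only through the output (Y). This equation tells us how to get the gradient of the loss with respect to our input (dL/dX): we multiply the gradient of the loss with respect to the output (dL/dY) by the transpose of our weight matrix ($W^\top$). Note that PyTorch’s <code>nn.Linear</code> stores its weight with shape <code>(out_features, in_features)</code>, i.e. the transpose of W here. That is why the code below computes <code>X @ W.t()</code> in the forward pass and <code>grad_Y @ W</code> for the input gradient.</p>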
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2084.png" /></p> | |
<p>Likewise, we can use the chain rule to compute the gradient w.r.t. the weight:</p>
<p>$$
\frac{dL}{dW} = \frac{dL}{dY} \frac{dY}{dW} = X^\top \frac{dL}{dY}
$$</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2085.png" /></p> | |
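<p>For $Y = XW$, the matrix form of these gradients is $\frac{dL}{dX} = \frac{dL}{dY} W^\top$ and $\frac{dL}{dW} = X^\top \frac{dL}{dY}$. A quick way to check it is a finite-difference test. The pure-Python sketch below uses a toy loss $L = \sum_{ij} G_{ij} Y_{ij}$, chosen so that $\frac{dL}{dY} = G$ exactly. The loss and the helper names are illustrative assumptions, not part of any library.</p>

```python
import random

def matmul(A, B):
    """Plain-list matrix product A @ B."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(A):
    return [list(row) for row in zip(*A)]

def toy_loss(X, W, G):
    # L = sum_ij G_ij * Y_ij with Y = X @ W, so dL/dY is exactly G
    Y = matmul(X, W)
    return sum(G[i][j] * Y[i][j] for i in range(len(Y)) for j in range(len(Y[0])))

def numerical_grad(f, M, eps=1e-6):
    """Central finite differences of the scalar f() w.r.t. every entry of M (restored afterwards)."""
    grad = [[0.0] * len(M[0]) for _ in M]
    for i in range(len(M)):
        for j in range(len(M[0])):
            original = M[i][j]
            M[i][j] = original + eps
            plus = f()
            M[i][j] = original - eps
            minus = f()
            M[i][j] = original
            grad[i][j] = (plus - minus) / (2 * eps)
    return grad

def max_abs_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))

random.seed(0)
n, d_in, d_out = 3, 4, 2
X = [[random.uniform(-1, 1) for _ in range(d_in)] for _ in range(n)]
W = [[random.uniform(-1, 1) for _ in range(d_out)] for _ in range(d_in)]
G = [[random.uniform(-1, 1) for _ in range(d_out)] for _ in range(n)]  # plays the role of dL/dY

grad_X = matmul(G, transpose(W))   # dL/dX = dL/dY @ W^T
grad_W = matmul(transpose(X), G)   # dL/dW = X^T @ dL/dY

# Both differences should be close to zero
print(max_abs_diff(grad_X, numerical_grad(lambda: toy_loss(X, W, G), X)))
print(max_abs_diff(grad_W, numerical_grad(lambda: toy_loss(X, W, G), W)))
```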
| <p>Here is a snippet of code to clarify all the concepts above:</p> | |
<p>```python
import os

import torch
import torch.distributed as dist

def split_tensor(tensor, dim):
    # Keep this rank's shard of `tensor` along `dim` (assumes it divides evenly by world size)
    return tensor.chunk(dist.get_world_size(), dim=dim)[dist.get_rank()].contiguous()

def column_linear_forward(X, local_W, group):
    Y_local = X @ local_W.t()
    return Y_local

def column_linear_backward(local_grad_Y, X, local_W, group):
    local_grad_X = local_grad_Y @ local_W
    # X is replicated, so each rank only holds a partial input gradient: sum them
    dist.all_reduce(local_grad_X, group=group)
    grad_W = local_grad_Y.t() @ X
    return local_grad_X, grad_W

def row_linear_forward(local_X, local_W, group):
    Y_local = local_X @ local_W.t()
    dist.all_reduce(Y_local, group=group)
    Y = Y_local
    return Y

def row_linear_backward(grad_Y, X, local_W, group):
    local_grad_X = grad_Y @ local_W
    grad_W = grad_Y.t() @ X
    return local_grad_X, grad_W

def example_column_row_linear():
    # torchrun --nproc_per_node=2 tp_all_reduce.py
    group = dist.distributed_c10d._get_default_group()

    X_ref = torch.arange(4 * 2, device="cuda", dtype=torch.float32, requires_grad=True).reshape(4, 2)
    W_ref_layer1 = torch.arange(1, 5, device="cuda", dtype=torch.float32, requires_grad=True).reshape(2, 2) * 10
    W_ref_layer2 = torch.arange(1, 5, device="cuda", dtype=torch.float32, requires_grad=True).reshape(2, 2)
    X_ref.retain_grad()
    W_ref_layer1.retain_grad()
    W_ref_layer2.retain_grad()

    # Run the in-place collectives outside autograd, since these tensors require grad
    with torch.no_grad():
        dist.broadcast(X_ref, src=0, group=group)
        dist.broadcast(W_ref_layer1, src=0, group=group)
        dist.broadcast(W_ref_layer2, src=0, group=group)

    # Detached copies for the manual (autograd-free) tensor-parallel computation
    X = X_ref.detach().clone()
    W_layer1 = W_ref_layer1.detach().clone()
    W_layer2 = W_ref_layer2.detach().clone()

    # Forward
    Y_ref_linear1 = X_ref @ W_ref_layer1.t()
    Y_ref_linear1.retain_grad()

    # We will transpose for matrix multiplication. As a result, we need to split row-wise
    Y_local_linear1 = column_linear_forward(X, split_tensor(W_layer1, dim=0), group)
    torch.testing.assert_close(Y_local_linear1, split_tensor(Y_ref_linear1, dim=1), rtol=1e-5, atol=1e-5)

    Y_local_linear2 = row_linear_forward(Y_local_linear1, split_tensor(W_layer2, dim=1), group)
    Y_ref_linear2 = Y_ref_linear1 @ W_ref_layer2.t()
    torch.testing.assert_close(Y_local_linear2, Y_ref_linear2, rtol=1e-5, atol=1e-5)

    # Backward
    Y_ref_linear2.sum().backward()
    grad_Y = torch.ones_like(Y_ref_linear2)
    grad_X_linear2, grad_W_linear2 = row_linear_backward(grad_Y, Y_local_linear1, split_tensor(W_layer2, dim=1), group)
    torch.testing.assert_close(grad_X_linear2, split_tensor(Y_ref_linear1.grad, dim=1), rtol=1e-5, atol=1e-5)
    torch.testing.assert_close(grad_W_linear2, split_tensor(W_ref_layer2.grad, dim=1), rtol=1e-5, atol=1e-5)

    grad_X, grad_W = column_linear_backward(grad_X_linear2, X, split_tensor(W_layer1, dim=0), group)
    torch.testing.assert_close(grad_X, X_ref.grad, rtol=1e-5, atol=1e-5)
    torch.testing.assert_close(grad_W, split_tensor(W_ref_layer1.grad, dim=0), rtol=1e-5, atol=1e-5)

if __name__ == "__main__":
    dist.init_process_group("nccl", rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    example_column_row_linear()
```</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2086.png" /></p> | |
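<p>You can see why the column-parallel backward needs an all-reduce on the input gradient without any GPUs. The pure-Python sketch below simulates two ranks with small integer matrices. It uses the $Y = XW$ convention rather than PyTorch’s transposed weight layout. It runs a column-parallel layer followed by a row-parallel layer and checks the results against the unsharded computation. Summing over the simulated ranks plays the role of <code>dist.all_reduce</code>. All helper names are our own.</p>

```python
def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(A):
    return [list(row) for row in zip(*A)]

def add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def columns(A, idx):
    return [[row[j] for j in idx] for row in A]

def rows(A, idx):
    return [A[i] for i in idx]

X = [[0, 1], [2, 3], [4, 5], [6, 7]]
W1 = [[10, 20], [30, 40]]
W2 = [[1, 2], [3, 4]]
G = [[1, 1] for _ in range(4)]  # dL/dY2 for L = sum(Y2)
shards = [[0], [1]]  # rank r owns output column r of layer 1 and input row r of layer 2

# Reference (unsharded) forward and input gradient
Y1 = matmul(X, W1)
Y2 = matmul(Y1, W2)
grad_X_ref = matmul(matmul(G, transpose(W2)), transpose(W1))

# Forward: column-parallel layer 1, then row-parallel layer 2
Y1_local = [matmul(X, columns(W1, s)) for s in shards]
Y2_partial = [matmul(Y1_local[r], rows(W2, s)) for r, s in enumerate(shards)]
Y2_tp = add(Y2_partial[0], Y2_partial[1])  # all-reduce in the row-parallel forward

# Backward: the row-parallel layer needs no communication, the column-parallel one does
grad_Y1_local = [matmul(G, transpose(rows(W2, s))) for s in shards]
grad_X_partial = [matmul(grad_Y1_local[r], transpose(columns(W1, s))) for r, s in enumerate(shards)]
grad_X_tp = add(grad_X_partial[0], grad_X_partial[1])  # all-reduce in the column-parallel backward

print(Y2_tp == Y2)                       # True
print(grad_X_tp == grad_X_ref)           # True
print(grad_X_partial[0] == grad_X_ref)   # False: one rank alone only has a partial gradient
```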
| <p><strong>TODO</strong> add these illustrations somewhere? I found them helpful:</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2087.png" /></p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2088.png" /></p> | |
| <h2>A3: ZeRO-R</h2> | |
<p>To further optimize memory usage in large-scale training, DeepSpeed ZeRO-R introduces several techniques that reduce the residual memory footprint during forward and backward propagation: activations, temporary buffers and fragmented memory. The key strategies are partitioned activation checkpointing, constant-size buffers, and memory defragmentation.</p>
| <h3>$P_a:$ Partitioned Activation Checkpointing</h3> | |
<p>In Tensor Parallelism, activations are replicated across the GPUs of the TP group. ZeRO instead partitions the activations across these GPUs. During the backward pass, it materializes them in replicated form one layer at a time, through an all-gather, right before each activation is used in computation. With partitioned activation checkpointing, ZeRO reduces the activation footprint by a factor proportional to the TP degree.</p>
<h3>$C_B:$ Constant Size Buffers</h3>
<p>The larger the data sent in each all-reduce operation, the higher the achieved bandwidth. Megatron-LM therefore fuses all the data into a single buffer, which can become very large. For example, 3B parameters in fp32 already take 3×10<sup>9</sup> × 4 bytes = 12GB. By using a fixed-size buffer instead, DeepSpeed achieves sufficient bandwidth without running into out-of-memory (OOM) issues.</p>
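<p>A constant-size buffer means gradients are communicated in fixed-size chunks instead of one giant fused buffer. Below is a minimal pure-Python sketch of such bucketing logic. It is our own illustration, not DeepSpeed’s code. Tensors, given by their element counts, are grouped greedily in order into buckets of at most <code>bucket_cap</code> elements. A tensor larger than the cap gets a bucket of its own. Each bucket would then be all-reduced separately, so the temporary buffer never exceeds the cap or the largest single tensor.</p>

```python
def assign_buckets(numels, bucket_cap):
    """Greedily group tensors (given by element count) into buckets of at most bucket_cap elements."""
    buckets, current, current_size = [], [], 0
    for idx, n in enumerate(numels):
        if current and current_size + n > bucket_cap:
            buckets.append(current)  # close the full bucket
            current, current_size = [], 0
        current.append(idx)
        current_size += n
    if current:
        buckets.append(current)
    return buckets

# Hypothetical gradient sizes (in elements) and a cap of 1000 elements per bucket
numels = [400, 300, 500, 1200, 100, 100]
print(assign_buckets(numels, bucket_cap=1000))  # [[0, 1], [2], [3], [4, 5]]
```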
<h3>$M_D$: Memory Defragmentation</h3>
<p>During forward propagation with activation checkpointing, only selected activations are retained for backpropagation. The majority are discarded, since they can be recomputed later. This mixes short-lived memory (discarded activations) with long-lived memory (checkpointed activations), which can lead to memory fragmentation. In backward propagation, parameter gradients are long-lived, whereas activation gradients and the other buffers needed for gradient computation are short-lived. To reduce memory fragmentation, DeepSpeed pre-allocates contiguous memory and moves the long-lived tensors (activation checkpoints and parameter gradients) into this pre-allocated space.</p>
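<p>The idea behind pre-allocating contiguous memory can be sketched in pure Python. A flat <code>array</code> and <code>memoryview</code> slices stand in for GPU memory and tensors. This illustrates the concept only and is not DeepSpeed’s allocator. Long-lived tensors are laid out back to back in one buffer allocated up front, so freeing short-lived tensors elsewhere cannot leave holes between them. The function name and sizes are hypothetical.</p>

```python
from array import array

def pack_into_contiguous(sizes):
    """Allocate one flat float32 buffer and return a view per long-lived tensor."""
    buffer = array("f", [0.0] * sum(sizes))  # single up-front allocation
    flat = memoryview(buffer)
    views, offset = [], 0
    for n in sizes:
        views.append(flat[offset:offset + n])  # a view into the buffer, no copy
        offset += n
    return buffer, views

# Hypothetical long-lived tensors: two activation checkpoints and one gradient
buffer, (ckpt_a, ckpt_b, grad) = pack_into_contiguous([4, 4, 2])
ckpt_b[0] = 1.5  # writing through a view updates the shared buffer
grad[1] = -2.0
print(list(buffer))  # [0.0, 0.0, 0.0, 0.0, 1.5, 0.0, 0.0, 0.0, 0.0, -2.0]
```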
| <h3>Communication Analysis of ZeRO-R</h3> | |
| <p><strong>Communication Overhead in Parallelism Strategies</strong></p> | |
<p>With the $P_a$ strategy, the added communication volume is generally less than one-tenth of the volume of baseline Tensor Parallelism (also known as Model Parallelism).</p>
| <p><strong>Communication Overhead of Megatron-LM with Tensor Parallelism (TP):</strong></p> | |
<p>In Megatron-LM with activation checkpointing, only the input activations of the transformer blocks are stored, and the rest are discarded to save memory. This requires an additional forward computation during the backward pass to recompute the activation values. Each transformer block performs two all-reduce operations of size $b \times s \times h$ (batch size × sequence length × hidden dimension) during the forward pass, two more during re-computation, and another two during the backward pass. That makes six all-reduce operations in total. Since each all-reduce moves roughly twice its message size, the total communication cost is $12 \times b \times s \times h$.</p>
| <p><strong>Communication Overhead of $P_a$ with ZeRO-R:</strong></p> | |
<p>ZeRO-R with $P_a$ only saves the input activation of each transformer block. The only extra communication is therefore one all-gather of this input per block, a volume of $b \times s \times h$ (batch size, sequence length, hidden dimension). Megatron’s Tensor Parallelism already moves $12 \times b \times s \times h$, so this adds approximately 1/12, less than 10% of the original volume. In return, $P_a$ divides the activation memory by the Tensor Parallelism (TP) degree. This allows a proportionally larger batch size and can reduce the data-parallel communication volume by up to the TP degree, at the cost of the roughly 10% increase in tensor-parallel communication. This trade-off can significantly boost efficiency, especially when data-parallel communication becomes a performance bottleneck.</p>
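<p>To make the ratio concrete, here is a small back-of-the-envelope calculator. It counts elements per transformer block for a batch size $b$, sequence length $s$ and hidden size $h$. Each all-reduce of a $b \times s \times h$ tensor moves about $2 \times b \times s \times h$ elements, and $P_a$ adds one all-gather of $b \times s \times h$. The values of <code>b</code>, <code>s</code> and <code>h</code> are arbitrary examples, not taken from a specific model.</p>

```python
def megatron_tp_volume(b, s, h):
    # 2 all-reduces in forward + 2 in recomputation + 2 in backward,
    # each moving about 2 * b * s * h elements
    return 6 * 2 * b * s * h

def pa_extra_volume(b, s, h):
    # one all-gather of the checkpointed block input per transformer block
    return b * s * h

b, s, h = 8, 4096, 8192  # example values
base = megatron_tp_volume(b, s, h)
extra = pa_extra_volume(b, s, h)
print(f"Megatron TP per block: {base / 1e9:.2f} G elements")
print(f"P_a extra per block:   {extra / 1e9:.2f} G elements ({100 * extra / base:.1f}% overhead)")
```

The overhead is 1/12 (about 8.3%) regardless of the chosen sizes, since both terms scale with $b \times s \times h$.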
<h2>A5: Memory profile</h2>
<p>You can quickly profile the memory usage of your model with <code>torch.profiler.profile</code>: it records all the operations and variables used within its context and classifies them into categories. Here’s a small example of how to profile Llama-3.2-1B with a minimal training loop:</p>
<p>```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def trace_handler(prof: torch.profiler.profile):
    # Construct the trace file.
    prof.export_chrome_trace("profile.json.gz")
    # Construct the memory timeline file.
    prof.export_memory_timeline("profile.html", device=device)

model_name = "meta-llama/Llama-3.2-1B"
device = "cuda:0"
num_epochs = 4

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_texts = ["Large models require a lot of GPU memory!\n" * 128] * 8
input_ids = tokenizer(input_texts, return_tensors="pt").input_ids.to(device)

# Set up optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
    on_trace_ready=trace_handler,
) as prof:
    # Training loop
    for epoch in range(num_epochs):
        # Forward pass
        loss = model(input_ids, labels=input_ids).loss
        # Backward pass
        loss.backward()
        # Update weights
        optimizer.step()
        # Reset gradients
        optimizer.zero_grad()
        # Tell the profiler that a training step has finished
        prof.step()
```</p>
| <h2>TP: Practical PyTorch Implementation</h2> | |
| <p>To implement our own Column Parallelism in PyTorch, we need to focus on two key aspects: forward and backward passes. This involves defining a custom autograd function that manages the specific operations required for column parallelism. </p> | |
| <ul> | |
<li>In the forward pass, the linear layer computation is already available. Normally, a copy (broadcast) operation is needed to ensure that data is aligned across GPUs. In practice, however, the data is already distributed to GPUs at the beginning of each step, so you can skip this copy step.</li>
<li>In the backward pass, the key operation for column parallelism is the <strong>all-reduce</strong> step. It sums the input gradients computed on each GPU across all GPUs.</li>
| </ul> | |
<p>```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def _reduce(input_):
    """Minimal stand-in for the all-reduce helper: sum `input_` over all ranks of the
    default process group (assumed here to be the tensor-parallel group)."""
    if dist.get_world_size() == 1:
        return input_
    dist.all_reduce(input_, op=dist.ReduceOp.SUM)
    return input_

class _CopyToModelParallelRegion(torch.autograd.Function):
    """copy(identity) in forward pass, all reduce in backward pass"""
    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)

# This is the `f` function in the paper: https://arxiv.org/abs/1909.08053
def copy_to_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
    return _CopyToModelParallelRegion.apply(input_)

# Core logic of a column-parallel linear layer
def linear_with_all_reduce(input_, weight, bias):
    input_parallel = copy_to_model_parallel_region(input_)
    output = F.linear(input_parallel, weight, bias)  # X W_i^T + b_i, output is Y_i
    return output
```</p>
| <h3>Gelu code</h3> | |
<p>If you prefer code, let’s explore this behavior with the following snippet:</p>
<p>```python
import torch
from torch.nn.functional import gelu

def example_gelu():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    X = torch.randn(4, 2, device=device, dtype=torch.float32)
    W = torch.randn(2, 2, device=device, dtype=torch.float32)

    # Column linear: split W along its columns (output features)
    W_0, W_1 = W.chunk(2, dim=1)
    y_col_1 = torch.cat([gelu(X @ W_0), gelu(X @ W_1)], dim=1)
    y_col_2 = gelu(torch.cat([X @ W_0, X @ W_1], dim=1))
    # All match: GeLU is elementwise, so it can be applied to each shard independently
    torch.testing.assert_close(y_col_1, y_col_2, rtol=1e-5, atol=1e-5)

    # Row linear: split X along its columns and W along its rows
    X_0, X_1 = X.chunk(2, dim=1)
    W_0, W_1 = W.chunk(2, dim=0)
    y_row_1 = gelu(X_0 @ W_0) + gelu(X_1 @ W_1)
    y_row_2 = gelu(X_0 @ W_0 + X_1 @ W_1)
    # Mismatch: the GeLU of a sum is not the sum of GeLUs, so the partial results
    # must be summed (all-reduced) before applying GeLU
    assert not torch.allclose(y_row_1, y_row_2, rtol=1e-5, atol=1e-5)

example_gelu()
```</p>
| <h3>Interconnect</h3> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2089.png" /></p> | |
| <h2>How to profile your code</h2> | |
| <p>The profiler is a tremendously useful tool and easy to use. It takes three steps to profile your program:</p> | |
| <ol> | |
| <li>Create a profiler (a context manager)</li> | |
| <li>Wrap your training code using <code>with profiler</code></li> | |
<li>Call <code>profiler.step()</code> at each training step</li>
| </ol> | |
<p>```python
import torch

# Assumes `model`, `dataloader`, `device` and `profiler_output_dir` are defined elsewhere
profiler = torch.profiler.profile(  # step 1. Create your profiler
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=10, warmup=10, active=5, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler(profiler_output_dir),
    record_shapes=True,
    profile_memory=True,
    with_flops=True,
    with_modules=True,
    with_stack=True
)
with profiler:  # step 2. Wrap the training with profiler
    for data in dataloader:
        input_ids, label_ids = data['input_ids'].to(device), data['label_ids'].to(device)
        # Forward pass
        outputs = model(input_ids)
        # (loss computation, backward pass and optimizer step would go here)
        profiler.step()  # step 3. Signal the end of each training step
```</p>
<p>After running this code, you will find <code>*.trace.json</code> files under <code>profiler_output_dir</code>. To visualize the results, the easiest way is to open Google Chrome, go to <code>chrome://tracing/</code>, and drag the file into it. For more details, we invite you to check out the amazing <a href="https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html"><strong>tutorial</strong></a> created by PyTorch.</p>
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2090.png" /></p> | |
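<p>The <code>schedule</code> argument decides which steps are actually recorded. As an illustration, the pure-Python function below mirrors our reading of how <code>torch.profiler.schedule</code> maps a 0-based step counter to an action. The counter advances with each <code>profiler.step()</code> call, and the sketch ignores the <code>skip_first</code> option. Check the PyTorch documentation for the authoritative semantics. With <code>wait=10, warmup=10, active=5, repeat=1</code>, as in the profiling snippet above, only steps 20 to 24 are recorded.</p>

```python
def profiler_action(step, wait, warmup, active, repeat):
    """Return 'none', 'warmup' or 'record' for a 0-based step (sketch of torch.profiler.schedule)."""
    cycle = wait + warmup + active
    if repeat > 0 and step >= repeat * cycle:
        return "none"  # all requested cycles are done
    pos = step % cycle
    if pos < wait:
        return "none"
    if pos < wait + warmup:
        return "warmup"  # profiler runs, but results are discarded
    return "record"

recorded = [s for s in range(40)
            if profiler_action(s, wait=10, warmup=10, active=5, repeat=1) == "record"]
print(recorded)  # [20, 21, 22, 23, 24]
```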
<h2>Formulas for the compute / comms balance</h2>
| <p>```markdown</p> | |
| <p>Estimates: (number of elements, need to multiply by bytes_per_element) | |
| activs: hidden_states : seq * bs * h | |
| model = grads: weight : h * h | |
| optimstates : 2 * model = 2 * h * h | |
| transformer block : 16h^2 ( = attn + mlp = 3h^2(qkv) + h^2(out_proj) + 8h^2(gate_up) + 4h^2(gate_down) ) (if glu else 12h^2) | |
| total_model_params ~= 16h^2 * num_layers (if glu else 12h^2 * num_layers) (missing embeds + lmhead terms = vocab*hidden_size if no pos_embeds) | |
| C_fwd ~ 2num_tokens * num_params | |
| C = C_fwd + C_bwd = 6num_tokens * num_params = 6 * mbs * seq * num_layers * 16h^2</p> | |
| <p>Comms payload size: | |
| * DDP: (allreduce grads_bf16 bwd) | |
| grads = params = num_layers *16h^2 = num_layers * 1.6B = num_layers * 3.2GB(bf16 precision) = 32GB (layers=10) | |
| -> backward will communicate 32GB (total) in chunks of ddp_bucket_cap_mb=25MB by default when bucketing</p> | |
| <p>-> payload = ddp_bucket_cap_mb | |
| -> total_comms_in_rank_per_step = total_model_params = num_layers *16h^2</p> | |
| <p>can be overlapped with backward pass | |
-> t_compute = C_bwd / peak_flops = 4 * num_tokens * num_params / peak_flops = 4 * mbs * seq * num_layers * 16h^2 / peak_flops
peak_bandwidth_allreduce = (S / t) * 2(DP-1)/DP
-> t_comm_bucket = t_idle = ddp_bucket_cap_mb * 2(DP-1) / (DP * peak_bandwidth_allreduce)</p>
| <p>For the overlap: t_comm_overlap = (N-1) * t_comm_bucket where N is the number of buckets and can be overlapped with backward pass starting from after ddp_bucket_cap_mb was generated (N-1)t_comp_bucket=(N-1)t_comp/N | |
| N = num_params / ddp_bucket_cap_mb | |
| -> t_comm / t_compute = (N-1) t_comm_bucket / ((N-1)t_comp/N) = N t_comm_bucket / t_comp | |
= N (ddp_bucket_cap_mb * 2(DP-1) / (DP * peak_bandwidth_allreduce)) / (4 * num_tokens * num_params / peak_flops)
= N (ddp_bucket_cap_mb * (DP-1) * peak_flops) / (2 * num_tokens * num_params * peak_bandwidth_allreduce * DP)
= N (ddp_bucket_cap_mb / (num_tokens * num_params)) * (DP-1) * peak_flops / (2 * peak_bw * DP)
= (1 / num_tokens) * (DP-1) * peak_flops / (2 * peak_bw * DP) must be < 1 (num_params cancels since N * ddp_bucket_cap_mb = num_params)</p>
| <ul> | |
| <li>FSDP: (allgather params_fp32 fwd, allgather params_fp32 bwd, reducescatter grads_fp32 bwd (cz data is sharded)) -> happens on each layer or several layers<blockquote> | |
| <p>I think we can choose to allgather either params_fp32 or params_fp16 depending on whether we want to reduce mem usage and increase comms or vice versa -> we still need to do fwd/bwd in bf16 and sharded optim step still updates fp32 | |
in a training step, assuming we apply FullyShardedDataParallel to every block, then one comm grp for each block and one comm grp for embeddings layer and lm head
each comm group = single allgather and single reducescatter
Transformer block = 1.6B -> h = (1.6*10**9/16)^0.5 = 10000 (closest is h=8k so llama70B)
if 8 GPUs, GBS=8:
transformer block per dp rank = 1.6B / 8 = 0.2B = 0.8GB (assuming 4 bytes per element)
grads in a transformer block = params in a transformer block
-> forward will communicate in chunks of 0.8GB (allgather) (8 chunks of 0.8GB circulating in network)
-> backward will communicate 2 times 0.8GB (allgather params and reducescatter grads) (grads in a transformer block = params in a transformer block)
so 3 comms of 0.8GB each. if num_layers = 10 -> 3*10 comms of 0.8GB each = 24GB (total)
| -> commsize = 24GB * num_gpus | |
| -> payload = layer_size / DP = transformer_block_size / DP = 16h**2 / DP (0.8GB if 4bytes) | |
| -> total_comms_in_rank_per_step = 3(fwd+2bwd) * num_layers * layer_size / DP = 3 * num_layers * 16h^2 / DP = 48 * num_layers * h^2/DP (24GB)</p> | |
| </blockquote> | |
| </li> | |
| </ul> | |
| <p>can be overlapped with the current layer's forward: | |
| while I forward the first layer, I'm all-gathering the next layer's params | |
| C_fwd = 2*seq*mbs*(16h^2) / DP = 32 * seq * mbs * h^2 / DP | |
| -> t_compute = C / peak_flops = 32 * seq * mbs * h^2 / (DP * peak_flops) | |
| peak_bandwidth_allgather = (S / t) * (DP-1)/DP | |
| -> t_comm = 16h^2 * (DP-1) / (DP * peak_bw_allgather)</p> | |
| <p>-> t_comm / t_compute = (DP-1) * peak_flops / (2 * seq * mbs * peak_bw)</p> | |
| <ul> | |
| <li>TP: (allgather activs_bf16 fwd, reducescatter activs_bf16 bwd) -> happens on each linear | |
| activs = seq * mbs * h | |
| if h = 10k, seq = 4000, mbs = 8 -> activs = seq * mbs * 10k = 4000 * 8 * 10k = 320M | |
| activs per TP rank = 320M / 8 = 40M = 80MB (bf16 precision) | |
| -> forward per activs will communicate in chunks of activs/TP = 80MB (allgather) (8 chunks of 80MB circulating in network) | |
| -> backward per grads will communicate in chunks of grads/TP = 80MB (reducescatter) (8 chunks of 80MB circulating in network) | |
| so for a transformer block: we have 4 linears | |
| -> forward will communicate 4 times 80MB (allgather) = 320MB (total) | |
| -> backward will communicate 4 times 80MB (reducescatter) = 320MB (total) | |
| and for 10 layers: | |
| -> forward will communicate 10 * 320MB = 3.2GB (total) | |
| -> backward will communicate 10 * 320MB = 3.2GB (total)</li> | |
| </ul> | |
| <p>-> payload = activs/TP = seq * mbs * h / TP (80MB if 2bytes) | |
| -> total_comms_in_rank_per_step = 2(fwd+bwd) * num_layers * 4 (linears/layer) * activs/TP = 8 * num_layers * seq * mbs * h/TP (6.4GB) | |
| https://www.determined.ai/blog/tp?t | |
| can be overlapped with the next linear (because we all-gather, then do the linear)</p> | |
| <p>for a transformer block: | |
| C_fwd = 2 * seq * mbs * (8h^2 + 4h^2) / TP = 24 * seq * mbs * h^2 / TP | |
| -> t_compute = C / peak_flops = 24 * seq * mbs * h^2 / (TP * peak_flops)</p> | |
| <p>peak_bandwidth_allgather = (S / t) * (TP-1)/TP | |
| -> t_comm = seq * mbs * h * (TP-1) / (TP * peak_bw_allgather)</p> | |
| <p>-> t_comm / t_compute = seq * mbs * h * (TP-1) * peak_flops / (24 * seq * mbs * h^2 * peak_bw) = (TP-1) * peak_flops / (24 * h * peak_bw)</p> | |
| <ul> | |
| <li>PP: (recv activs_bf16 and send activs_bf16 in fwd, recv grads_bf16 and send grads_bf16 in bwd) -> happens each gradient accumulation step (gas)<blockquote> | |
| <p>grads_bf16 is not the model's grads: it is the intermediate gradient with respect to activs_bf16, created during the backward pass. | |
| activs = seq * mbs * h | |
| if h = 10k, seq = 4000, mbs = 1 -> activs = seq * mbs * 10k = 4000 * 1 * 10k = 40M | |
| a PP rank in a training step will see (gas) microbatches | |
| for a single microbatch: | |
| -> forward per microbatch will communicate 2 times (recv and send) activs = 2 * 40M = 80M | |
| -> backward per microbatch will communicate 2 times (recv and send) grads = 2 * 40M = 80M | |
| so for (gas=8) microbatches: | |
| -> forward will communicate (gas) * 80M = 80M * 8 = 640M | |
| -> backward will communicate (gas) * 80M = 80M * 8 = 640M</p> | |
| </blockquote> | |
| </li> | |
| </ul> | |
| <p>-> payload = activs = seq * mbs * h (40M elements) | |
| -> total_comms_in_rank_per_step = 2 (fwd+bwd) * gas * 2 (recv+send) * activs = 4 * gas * seq * mbs * h (1.28G elements, i.e. 2.56GB in bf16) TODO: pick gas to have same gbs as before</p> | |
| <p>can be overlapped with the next microbatch's forward / backward | |
| C_fwd = num_layers_in_next_pp * C_fwd_layer = num_layers_in_next_pp * 2*seq*mbs*(16h^2) = num_layers_in_next_pp * 32 * seq * mbs * h^2 | |
| -> t_compute = C / peak_flops = num_layers_in_next_pp * 32 * seq * mbs * h^2 / peak_flops | |
| peak_bandwidth_p2p = (S / t) | |
| -> t_comm = seq * mbs * h / peak_bw_p2p</p> | |
| <p>for a single microbatch: | |
| -> t_comm / t_compute = seq * mbs * h * peak_flops / (num_layers_in_next_pp * 32 * seq * mbs * h^2 * peak_bw) = peak_flops / (32 * h * num_layers_in_next_pp * peak_bw)</p> | |
| <p>```</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2091.png" /></p> | |
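<p>The per-rank communication volumes estimated in the notes above can be reproduced with a few lines of Python. This is a minimal sketch that simply encodes the notes' own assumptions (16h^2 parameters per transformer block, 4 linears per block, fp32 payloads for FSDP and bf16 payloads for TP and PP, decimal gigabytes); the function names are hypothetical and not part of any library.</p>

```python
GB = 1e9  # decimal gigabytes, as used in the notes


def fsdp_comm_bytes(h, num_layers, dp, bytes_per_elem=4):
    """Per-rank FSDP volume per step: for every block, one all-gather in the forward
    pass plus one all-gather (params) and one reduce-scatter (grads) in the backward
    pass, each moving a 1/DP shard of the block's 16*h^2 parameters."""
    return 3 * num_layers * 16 * h**2 / dp * bytes_per_elem


def tp_comm_bytes(seq, mbs, h, num_layers, tp, linears_per_layer=4, bytes_per_elem=2):
    """Per-rank TP volume per step: one all-gather (fwd) and one reduce-scatter (bwd)
    of a seq*mbs*h/TP activation chunk for every linear layer."""
    return 2 * num_layers * linears_per_layer * seq * mbs * h / tp * bytes_per_elem


def pp_comm_bytes(seq, mbs, h, gas, bytes_per_elem=2):
    """Per-rank PP volume per step: for each of the gas microbatches, recv+send of
    activations (fwd) and recv+send of activation gradients (bwd)."""
    return 2 * 2 * gas * seq * mbs * h * bytes_per_elem


if __name__ == "__main__":
    print(f"FSDP: {fsdp_comm_bytes(h=10_000, num_layers=10, dp=8) / GB:.2f} GB")
    print(f"TP:   {tp_comm_bytes(seq=4000, mbs=8, h=10_000, num_layers=10, tp=8) / GB:.2f} GB")
    print(f"PP:   {pp_comm_bytes(seq=4000, mbs=1, h=10_000, gas=8) / GB:.2f} GB")
```

<p>With the notes' example numbers this prints 24.00 GB for FSDP, 6.40 GB for TP and 2.56 GB for PP, the latter being the 1.28G elements of the PP estimate stored in bf16.</p>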
| <h2>Integrating Context Parallelism with TP/SP</h2> | |
| <p>We’ve seen that both TP/SP and CP shard the activations along the sequence dimension and that both require communication in the attention module. Wouldn’t that create issues? Not at all!</p> | |
| <p>In order to integrate CP with TP/SP we just have to:</p> | |
| <ol> | |
| <li><strong>Split the input sequence across the context parallel group:</strong> At the beginning of the forward pass, we split the input sequence and dispatch the chunks to the different ranks of the context parallel process group (CP0 gets seq_chunk0 and CP1 gets seq_chunk1). This reduces activation memory usage by a factor equal to the size of the context parallel process group.</li> | |
| <li><strong>Entering the attention module:</strong> Since each TP rank only holds the QKV heads it’s responsible for, we all-gather along the sequence dimension, which was sharded by sequence parallelism; this all-gather can be overlapped with the QKV projection.</li> | |
| <li><strong>Replace standard attention with ring attention:</strong> Each TP rank relies on ring attention to compute the correct attention results during both the forward and backward passes. For example, all CP ranks within TP=0 need to see the full KV sequence (its chunks are passed around the ring) to compute attention, but each rank only stores the KV of its own sequence chunk, which reduces activation memory by a factor of CP.</li> | |
| </ol> | |
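<p>The ring attention step above can be checked with a small single-process NumPy simulation, which also shows why it produces the same result as standard attention (up to floating-point rounding). This is a sketch under simplifying assumptions: a single head without causal masking (each TP rank would run the same logic on its own subset of heads), the outer loop over r stands in for CP ranks that would really run concurrently, and the function names are illustrative. Each rank keeps its query chunk, visits every K/V chunk in ring order, and merges the partial results with a running row maximum and a running softmax denominator (the online-softmax trick).</p>

```python
import numpy as np


def full_attention(q, k, v):
    """Reference softmax attention over the whole sequence (no causal mask)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    probs = np.exp(scores - scores.max(axis=1, keepdims=True))
    return (probs / probs.sum(axis=1, keepdims=True)) @ v


def ring_attention(q_chunks, k_chunks, v_chunks):
    """Simulate ring attention over P = len(q_chunks) context-parallel ranks."""
    p = len(q_chunks)
    outputs = []
    for r in range(p):  # in a real run, every CP rank executes this loop concurrently
        q = q_chunks[r]
        scale = 1.0 / np.sqrt(q.shape[-1])
        row_max = np.full(q.shape[0], -np.inf)  # running max of the scores per query
        denom = np.zeros(q.shape[0])  # running softmax denominator
        numer = np.zeros((q.shape[0], v_chunks[0].shape[-1]))  # running weighted sum of V
        for step in range(p):
            # at step s, rank r holds the K/V chunk that started on rank (r - s) mod P;
            # a real implementation would launch the send/recv of the next chunk here,
            # overlapped with the computation below
            src = (r - step) % p
            scores = q @ k_chunks[src].T * scale
            new_max = np.maximum(row_max, scores.max(axis=1))
            correction = np.exp(row_max - new_max)  # rescale old partial sums to the new max
            weights = np.exp(scores - new_max[:, None])
            denom = denom * correction + weights.sum(axis=1)
            numer = numer * correction[:, None] + weights @ v_chunks[src]
            row_max = new_max
        outputs.append(numer / denom[:, None])
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
out = ring_attention(np.array_split(q, 4), np.array_split(k, 4), np.array_split(v, 4))
print(np.allclose(out, full_attention(q, k, v)))  # True
```

<p>No rank ever materializes the full K/V sequence at once: it only needs its own chunk plus the one currently passing through, which is where the activation memory saving by a factor of CP comes from.</p>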
| <p><img alt="TP=0 has GPU0 and GPU2 whereas CP=0 has GPU0 and GPU1 | |
| TP/SP shards the Q/K/V heads across TP ranks (in this example GPU0 and GPU2 get QKV_green, and GPU1 and GPU3 get QKV_blue); since each head can operate independently of the others, we can apply ring attention within each TP rank" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2092.png" /></p> | |
| <p>TP=0 has GPU0 and GPU2, whereas CP=0 has GPU0 and GPU1. | |
| TP/SP shards the Q/K/V heads across TP ranks (in this example, GPU0 and GPU2 get QKV_green, and GPU1 and GPU3 get QKV_blue). Since each head can operate independently of the others, we can apply ring attention within each TP rank.</p> | |
| <p>Context parallelism is naturally compatible with data parallelism, which splits the input along the batch dimension.</p> | |
| <p>In fact, given an activation of shape $[\text{batch\_size}, \text{sequence\_length}, \text{hidden\_dimension}]$, data parallelism, sequence/context parallelism, and tensor parallelism split it along the 1st, 2nd, and 3rd dimensions, respectively, and these splits are independent of each other.</p> | |
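<p>Because the three splits act on different dimensions, the shard held by a given rank is simply the intersection of three independent slices. A minimal NumPy sketch, assuming each rank is identified by hypothetical (dp_rank, cp_rank, tp_rank) coordinates:</p>

```python
import numpy as np


def local_shard(x, dp, cp, tp, dp_rank, cp_rank, tp_rank):
    """Return the part of a [batch_size, sequence_length, hidden_dimension] activation
    held by one rank: DP splits dim 0, CP/SP splits dim 1 and TP splits dim 2."""
    x = np.array_split(x, dp, axis=0)[dp_rank]
    x = np.array_split(x, cp, axis=1)[cp_rank]
    return np.array_split(x, tp, axis=2)[tp_rank]


x = np.arange(4 * 8 * 6).reshape(4, 8, 6)  # batch=4, seq=8, hidden=6
shard = local_shard(x, dp=2, cp=2, tp=3, dp_rank=1, cp_rank=0, tp_rank=2)
print(shard.shape)  # (2, 4, 2)
```

<p>Since slicing along different axes commutes, applying the three splits in any order yields the same shard, which is what makes the three forms of parallelism composable.</p>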
| <h2>The nanotron FP8 recipe</h2> | |
| <p>Through extensive experimentation, we identified two effective training recipes that allowed us to <strong>fully pretrain a 1B LLaMA model in FP8</strong>, covering both the forward and backward passes, while using an FP8 optimizer. More importantly, our approach successfully matched LLaMA-2’s pretraining learning rate. The result?</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2093.png" /></p> | |
| <p>A loss curve that perfectly matches the mixed-precision bfloat16 baseline (bfloat16 with FP32 master weights). We successfully used this recipe to train a 1B LLaMA on up to 100B tokens and a 7B LLaMA on up to 25B tokens.</p> | |
| <p>Here’s what worked:</p> | |
| <ul> | |
| <li><strong>Recipe 1</strong>: Architectural tweaks, optimizer changes (without gradient clipping, a rookie mistake), and an all-reduce modification.</li> | |
| <li><strong>Recipe 2</strong>: Recipe 1 with gradient clipping added and all of its architectural tweaks removed.</li> | |
| </ul> | |
| <p>Recipe 1: Architectural and optimizer changes (without gradient clipping, a silly mistake)</p> | |
| <ul> | |
| <li>Clipped softmax <a href="https://arxiv.org/abs/2306.12929">[Bondarenko et al]</a>: To avoid updating a token's representation, some attention heads allocate most of their attention probability mass to a fixed, common set of tokens with low information content. Since the output of the residual stream after the MLP goes through LayerNorm (a normalization operation), and softmax assigns high attention probability based on the relative differences between attention scores, the previous layer must produce a large outlier to maintain a high probability after normalization. This encourages the MLP to maximize outliers. Clipped softmax gives attention the ability to leave a token's representation unchanged without needing to maximize outliers.</li> | |
| <li>Layer-scale <a href="https://arxiv.org/abs/2103.17239v2">[Touvron et al]</a>: Add a trainable per-channel factor that scales down each block's output before it is added to the residual stream.</li> | |
| <li>Update clipping <a href="https://arxiv.org/abs/1804.04235">[Shazeer et al]</a>: The Adam optimizer tracks the gradient trajectory with two exponential moving averages, commonly with beta1 = 0.9 and beta2 = 0.95 in LLM training. Since 0.9 < 0.95, the second-moment estimate gives relatively more weight to past gradients, and less to the current one, than the first-moment estimate does. When the gradient signal suddenly increases, for example, the second moment no longer accurately represents the current gradient. Dividing the larger first moment by the square root of this outdated, smaller second moment yields a ratio > 1, and multiplying it by the learning rate gives a larger-than-desired update step, which causes overshoot. We slow down training by clipping the update (effectively lowering the learning rate) when the second moment shows that it is outdated.</li> | |
| <li>SmoothQuant <a href="https://arxiv.org/abs/2211.10438">[Xiao et al]</a>: It has been observed that activations are harder to quantize than weights because of outliers that emerge during training. Thanks to the associativity of matrix multiplication, we can migrate outliers from a to b while keeping a@b mathematically unchanged: divide each input channel of the activations a by a per-channel factor and multiply the matching row of the weights b by the same factor.</li> | |
| </ul> | |
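<p>The four changes above can be sketched in a few lines of NumPy. This is an illustrative sketch, not nanotron's implementation: the clipped-softmax stretch factors zeta and gamma are hypothetical example values, the update clipping applies the Adafactor rule (divide the update by max(1, RMS(update)/d)) to an Adam-style update m_hat / sqrt(v_hat), which is an adaptation since the exact variant used in the recipe is not specified above, and the SmoothQuant migration strength alpha=0.5 is an assumed default.</p>

```python
import numpy as np


def clipped_softmax(x, zeta=1.003, gamma=-0.003, axis=-1):
    """Clipped softmax: stretch softmax to (gamma, zeta), then clip to [0, 1], so a head
    can assign exactly 0 (or 1) probability without needing extreme input scores."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    p = e / e.sum(axis=axis, keepdims=True)
    return np.clip((zeta - gamma) * p + gamma, 0.0, 1.0)


def layer_scale_residual(residual, block_out, scale):
    """LayerScale: multiply the block output by a trainable per-channel vector
    (initialised to a small value) before adding it to the residual stream."""
    return residual + scale * block_out


def clipped_update(m_hat, v_hat, lr, d=1.0, eps=1e-8):
    """Adam-style step with Adafactor-style update clipping: if the RMS of the
    normalised update exceeds d (e.g. v_hat lags behind a sudden jump in the
    gradients), the whole step is scaled down."""
    u = m_hat / (np.sqrt(v_hat) + eps)
    rms = np.sqrt(np.mean(u**2))
    return -lr * u / max(1.0, rms / d)


def smooth_quant(x, w, alpha=0.5):
    """SmoothQuant: per input channel j, divide activations by s_j and multiply the
    matching weight row by s_j, so x @ w is unchanged but activation outliers shrink.
    x has shape [tokens, in_features], w has shape [in_features, out_features]."""
    s = np.abs(x).max(axis=0) ** alpha / np.abs(w).max(axis=1) ** (1 - alpha)
    return x / s, w * s[:, None]
```

<p>For instance, clipped_softmax of the scores [10, 0, 0] returns exactly [1, 0, 0], which a regular softmax can only approach as the score gap grows without bound.</p>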
| <p>Recipe 2: We added gradient clipping to recipe 1, removed all of its architectural changes, and kept the following implementation details.</p> | |
| <p>Implementation details</p> | |
| <ul> | |
| <li>We also observe that outliers emerge significantly in some layers, specifically the last layer, and find that not quantizing the first and last layers is critical. Other tricks: accumulate FP8 operations in bfloat16; keep the residual stream in float32; keep model activations, weights, and weight gradients in fp8e4m3; keep input gradients in fp8e5m2; and keep optimizer states in fp8e4m3.</li> | |
| <li>The all-reduce fix (preventing underflow in FP8): <a href="https://arxiv.org/abs/2309.14322">[Wortsman et al.]</a> have also shown that as model size increases, gradient magnitudes tend to shrink. PyTorch’s default all-reduce implementation uses pre-scaling, where each gradient is divided by the number of data parallel workers N before being summed: $g=\frac{g_1}{N}+\frac{g_2}{N}+\cdots+\frac{g_N}{N}$. However, when the data parallel (DP) size is large, this division can lead to underflow, especially in lower-precision formats like FP8, where the representable range is already limited. To mitigate this, we switched to post-scaling all-reduce, where gradients are summed first and divided only after the reduction: $g=\frac{g_1 +g_2 +\cdots + g_N}{N}$.</li> | |
| </ul> | |
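<p>Both the format choices and the all-reduce fix come down to the limited dynamic range of FP8. The sketch below rounds Python floats to the E4M3 and E5M2 grids (round-to-nearest-even, with subnormals and saturation; NaN/inf handling omitted) and compares pre-scaling and post-scaling on identical gradients. It is a simulation under stated assumptions, not nanotron code: we assume each worker's contribution is cast to FP8 for communication, that the summation itself happens in higher precision, and the helper names are hypothetical.</p>

```python
import math

# (mantissa bits, exponent bias, largest finite value); E4M3 is the "fn" variant without infinities
E4M3 = (3, 7, 448.0)
E5M2 = (2, 15, 57344.0)


def to_fp8(x, man_bits, bias, max_val):
    """Round x to the nearest value representable in the given FP8 format."""
    if x == 0.0:
        return 0.0
    min_exp = 1 - bias  # exponent of the smallest normal number
    exp = max(math.floor(math.log2(abs(x))), min_exp)  # clamped exponent: subnormal range below min_exp
    ulp = 2.0 ** (exp - man_bits)  # spacing between neighbouring FP8 values around x
    q = round(x / ulp) * ulp  # Python's round() is round-half-to-even
    return math.copysign(min(abs(q), max_val), x)  # saturate instead of overflowing


def allreduce_mean(grads, fmt, post_scale):
    """Average per-worker gradients, casting each worker's contribution to FP8."""
    n = len(grads)
    if post_scale:
        return sum(to_fp8(g, *fmt) for g in grads) / n  # sum first, divide once at the end
    return sum(to_fp8(g / n, *fmt) for g in grads)  # divide first (pre-scaling), then sum


grads = [2.0**-10] * 256  # 256 DP workers with the same small gradient value
print(allreduce_mean(grads, E5M2, post_scale=False))  # 0.0: each g/N underflows
print(allreduce_mean(grads, E5M2, post_scale=True))  # 0.0009765625, i.e. 2**-10
```

<p>The same helper also illustrates the split between formats: 2**-11 is exactly representable in E5M2 but rounds to 0 in E4M3, whose smallest subnormal is 2**-9, while 1000.0 saturates to 448.0 in E4M3 and rounds to 1024.0 in E5M2 (wider range, coarser precision).</p>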
| <h1>Overlapping computation and communication</h1> | |
| <p>We’ve now seen at least three examples of overlapping communication with computation to improve training efficiency:</p> | |
| <ul> | |
| <li>we overlapped the gradient all-reduce in data parallelism to sync gradients that were ready while the backward pass was progressing</li> | |
| <li>we overlapped the parameter gathering in ZeRO-3 with prefetching, making sure the weights are ready ahead of time for the computation</li> | |
| <li>and now in the ring attention mechanism</li> | |
| </ul> | |
| <p>The general idea is always the same: if some data will soon need to be communicated between workers and that communication is independent of the current computation, we can run the two in parallel, which is also called overlapping communication and computation.</p> | |
| <p>Let’s take a moment to look more closely at this fundamental tool for distributed training, using the example of Ring Attention and the PyTorch Profiler. In Ring Attention, we can overlap the sending and receiving of key/value pairs with the computation of attention scores. What does this look like?</p> | |
| <p><strong>Non-overlapping:</strong> If we don't overlap communication and computation, each computation (represented by the purple block) can only begin after the communication (green block) is complete, and the total time is the sum of communication and computation.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2094.png" /></p> | |
| <p><strong>Overlapping:</strong> However, if we manage to launch communication and computation in parallel, we eliminate the waiting time! Now the computations are launched immediately, one after the other, and the total time is <em>only</em> the sum of the computation times.</p> | |
| <p><img alt="image.png" src="The%20Ultra-Scale%20Playbook%20Training%20LLMs%20on%20GPU%20Clus%20af1b4137215e4e4eb1971e7dfa3185a9/image%2095.png" /></p> | |
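<p>A back-of-the-envelope timing model makes the difference between the two profiles explicit. This is a simplified sketch (constant per-step times, no kernel launch or synchronization overhead, hypothetical function names) of a ring with p steps: the first step uses the rank's local K/V chunk, and each later step needs a chunk received from the neighbour.</p>

```python
def ring_time_sequential(p, t_comp, t_comm):
    """Each transfer must finish before the next computation starts."""
    return p * t_comp + (p - 1) * t_comm


def ring_time_overlapped(p, t_comp, t_comm):
    """The transfer of the next chunk runs while the current chunk is being computed,
    so each of the first p - 1 steps takes max(t_comp, t_comm)."""
    return (p - 1) * max(t_comp, t_comm) + t_comp


print(ring_time_sequential(4, t_comp=2.0, t_comm=1.0))  # 11.0
print(ring_time_overlapped(4, t_comp=2.0, t_comm=1.0))  # 8.0, i.e. only the sum of computations
print(ring_time_overlapped(4, t_comp=2.0, t_comm=3.0))  # 11.0: communication is now the bottleneck
```

<p>As long as t_comm does not exceed t_comp, overlapping hides the communication completely; once it does, the step time is bounded by communication instead of computation.</p>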
| <p>Context parallelism has helped us get past the intra-node interconnect bottleneck, which blocked us from scaling TP across nodes. However, as you probably noticed, it only helps reduce memory constraints when activation memory dominates the memory budget due to long sequences. What if we are not working with super long sequences and the model weights alone are too big for a single node?</p> | |
| <p>Well, it turns out we have another, quite different, option called pipeline parallelism (PP), which it is now time to explore.</p> | |
| <p>[TODO: comment from Nouamane on comms overlapping with DP 512]</p> | 

