The Ultra-Scale Playbook: Training LLMs on GPU Clusters


Fueled by the scaling laws, the trend of training ever larger language models on vaster amounts of data has been driving progress in AI for the past couple of years. Initially, the development of the largest models happened exclusively behind the closed doors of a handful of research labs, but it has recently opened up with the release of models such as Llama 3.1 405B and DeepSeek R1. While these models have openly shared weights and their training recipes are described in technical reports, the challenging engineering involved in training them at the necessary infrastructure scale is still hidden between the lines of a handful of papers and complex training frameworks. This ~~long blog post~~ open-source book is here to open this black box!

In this book we invite you to follow us into the wonderful world of scaling the training of Large Language Models to tens, hundreds, and thousands of GPUs. It assumes you know the basics of LLM architecture and training, but are new to distributed training. This writing can be seen as the second part of a trilogy following our first blog on processing data for pre-training, the so-called “FineWeb blog post”. Having read both blog posts, you should have almost all the core knowledge needed to deeply understand how LLMs are being built nowadays; you'll just be missing the final spices, like data mixing and architecture choices, to complete the recipe (stay tuned…).

Pre-training LLMs from scratch now requires amounts of compute which exceed in almost every case the use of a single GPU or machine. The clusters used to train these models range from hundreds to thousands of nodes, each usually equipped with 4 to 8 GPUs. To make the best use of such expensive hardware, as well as to train in a reasonable time, a range of distributed training methods have been developed with the goal of ensuring that GPUs are highly utilized at all times. Efficiently scaling LLM training is also no longer confined to pretraining, as fine-tuning larger models on more domain-specific data is becoming the standard practice to achieve the best results.

In this post we’ll cover these scaling methods exhaustively while keeping a single story-line to understand where each technique comes from. We’ll cover data, tensor, pipeline and context parallelism as well as ZeRO and kernel fusion. The post is built on the following three foundations:

Quick intros on theory and concepts: before diving into code and experiments, we want to understand how each method works at a high level and what its advantages and limits are. You’ll learn which parts of a language model eat away at your memory and when during training this happens. You’ll learn how we can solve memory constraints by parallelizing the model and increase throughput by scaling up the number of GPUs. As a result you'll understand how the following widget to compute the memory breakdown of a transformer model works:

While this widget gives a theoretical breakdown, the following tool can be used to predict the actual memory usage:

image.png

Clear code implementations: theory is one thing, but we discover all kinds of edge cases and important details when we implement something. That’s why we link to implementation references where possible. Depending on the case, we’ll use two code references: the picotron repository is built for education, so it usually implements concepts in single, self-contained short files. On the other hand, to look at production-ready code, we’ll refer to the nanotron implementation, which is a production training codebase used at Hugging Face.

Picotron implements each key concept in a self-contained way, such that the method can be studied separately and in isolation.

Real training efficiency benchmarks: Finally, how to actually scale your LLM training depends on your infrastructure, such as the kind of chips, interconnect etc., and we can’t give a single unified recipe. What we will give, though, is a way to benchmark several setups, and it is what we have done on our cluster! We ran over 4100 distributed experiments with up to 512 GPUs to scan many possible distributed training layouts and model sizes. TODO: link to dataset too

An overview of the over 4000 experiments across all Llama architectures where each data point corresponds to an experiment launch.

As you can see, there’s a lot of ground to be covered. Before getting into the trenches of distributed training, let’s take a quick high-level look at what we’ll cover in the post.

TL;DR

This book is very extensive, so we decided to start with a very general overview of how you can think about distributed training. At a high level, the key challenge in scaling LLM training is to make a training step (forward/backward/optimizer step) with a large batch size as fast as possible.

When scaling up models and input batches, we quickly end up in situations where either our target batch size won't fit in memory, and/or the model itself is too large to fit in a single GPU's memory.

To solve this scaling issue we’ll need to carefully evaluate different parallelization strategies and find the optimal balance between three main factors:

  1. Memory Usage
    • Hard limitation - if a training step doesn't fit in memory, training cannot proceed
    • Sometimes we can increase compute (e.g. recomputation) or increase communication (e.g. ZeRO) to reduce memory usage
  2. Compute Efficiency
    • Memory transfer can also decrease compute efficiency.
    • We want our hardware to spend most time computing, so we need to reduce time spent on data transfers or unoptimized kernels.
    • GPUs need sufficient workload (large enough matrices/batch sizes) to maintain high utilization (compute-bound) otherwise they become memory-bound (limited by memory bandwidth).
  3. Communication overhead
    • Two main types. For GPUs: intra-node (NVLink TODO: bandwidth) and inter-node (network TODO: bandwidth)
    • Two main attributes: base latency and peak bandwidth. Base latency is a constant overhead that makes us want to do the least number of comms possible, and peak bandwidth controls how fast we can move data between GPUs
    • We prioritize using the fastest communication channels (like NVLink) for operations that occur frequently and/or block computation (e.g. tensor parallelism)
    • We want to minimize communication overhead as it keeps GPUs idle, so we try to overlap communication with compute as much as possible

But let’s not get too far ahead of ourselves and scale progressively. To guide you along the journey, and as a practical reference, we summarized the key concepts in a cheatsheet:

[TODO: ADD CHEATSHEET]

Now that we nailed a few key concepts and terms, let’s get started by revisiting the basic training steps of an LLM!

First Steps: Training on one GPU

Let’s start by quickly reviewing the very basics of model training before we start to scale to many GPUs. When a model is trained on a single GPU, the training typically consists of three steps:

  1. a forward pass which passes inputs through the model to yield its outputs,
  2. a backward pass to compute the gradients, and
  3. an optimization step using the gradients to update the parameters

It looks generally like this:

image.png

In this figure, the boxes on the top line can be seen as successive layers inside a model (same for the last line). The red boxes are the associated gradients for each of these layers, computed during the backward pass.

The batch size (bs) is one of the important hyper-parameters for model training and affects both model convergence and throughput.

If the batch size is too small, gradients will tend to be noisy and the model may not be able to converge to the most optimal performance; on the contrary, noisy gradients can be useful in early training to navigate quickly through the training landscape. On the other hand, a batch size that is too large will make less use of each training token, rendering convergence slower and wasting compute. You can find a nice discussion of this topic in OpenAI’s paper on large batch training or Section 4.2 of the MiniMax-01 technical report.

Batch size also affects the time it takes to train on a given text dataset: a small batch size will require more optimizer steps to train on the same amount of samples. Optimizer steps are costly (in compute time), and the total time to train will thus increase compared to using a larger batch size. This being said, note that the batch size can often be adjusted quite broadly around the optimal batch size without major impact on the performance of the model, i.e. the sensitivity of the final model performance to the exact batch size value is usually rather low around the optimal batch size.

In the LLM pretraining community, batch sizes are commonly reported in terms of tokens rather than in number of samples (bst = Batch Size Tokens); this makes training numbers generally independent of the exact input sequence length used during training.

In the simplest case, training on a single machine, the bs (in samples) and bst can be computed from the model input sequence length (seq) as follows:

bst = bs \times seq

A sweet spot for recent LLM training is typically on the order of 4-60 million tokens per batch. However, a typical issue when scaling the training of our model to these large batch sizes is out-of-memory issues, i.e. our GPU doesn’t have enough memory.

It’s time to tackle our first scaling problem: what if our model starts exploding GPU memory before we’ve reached our target batch size (maybe in some cases even when using the lowest possible batch size, bs=1)?

Let’s start by quickly understanding what led to our out-of-memory issue in the first place. This will help us gain some useful intuitions for later.

Memory usage in Transformers

When training a neural network model, one stores several items in memory: the model weights, the model gradients, the optimizer states, and the activations needed to compute the gradients.

These items are stored as tensors which come in different shapes and precisions. The shapes are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding as we’ll see later. Precision refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor.

So how can I quickly determine memory usage from these variables? One simple way is to do this empirically and just measure it.

Memory profiling a training step

Using this snippet [TODO: link to appendix A5], we can understand how memory is allocated throughout training. We can see that memory utilization is not a static thing but varies a lot during training and during a training step:

llama-1b-memory.png

Clearly the first step looks very different from the subsequent ones, but let’s first have a look at the general anatomy of a step: first the activations increase quickly as we do the forward pass, then during the backward pass the gradients build up and as the backward pass propagates, the stored activations used to compute the gradients are progressively cleared. Finally, we perform the optimization step during which we need all the gradients and then update the optimizer states before we start the next forward pass.

Why does the first step look different? The activations increase quickly and then plateau for a while. In this first step the torch cache allocator does a lot of preparation, setting up memory allocations to speed up the subsequent steps so that they don’t require searching for free memory blocks afterwards (see Zach’s blog). After the first step we also see the optimizer states appearing, which generally offsets the memory usage for further training steps.
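If you just want to eyeball these numbers yourself (without the full profiler snippet above), a few calls to PyTorch's memory counters already tell the story. The toy model below is obviously just a stand-in for a real LLM, and the snippet requires a CUDA device:

```python
import torch

def mem(label):
    print(f"{label:>8}: {torch.cuda.memory_allocated()/1e9:5.2f} GB "
          f"(peak {torch.cuda.max_memory_allocated()/1e9:5.2f} GB)")

model = torch.nn.Linear(8192, 8192).cuda()   # toy stand-in for a real model
optimizer = torch.optim.AdamW(model.parameters())
mem("init")

x = torch.randn(16, 8192, device="cuda")
loss = model(x).pow(2).mean(); mem("forward")   # activations now live on the GPU
loss.backward();               mem("backward")  # gradients allocated
optimizer.step();              mem("step")      # optimizer states appear after the first step
```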

Now that we have a first view of memory, let’s see how scaling up training is often a question of maximizing compute efficiency while keeping the memory requirements of these various items (activations, parameters, gradients, optimizer states) within the memory constraints of the GPUs.

Weights/grads/optimizer states memory

We can actually pretty easily estimate the memory needed for the model’s weights, gradients and optimizer states.

For a simple transformer LLM the number of parameters is given by the following formula:

N = h \cdot v + L \cdot (12h^2 + 13h) + 2h

In that equation, h is the hidden dimension, v the vocabulary size, and L the number of layers in the model. Note that looking at the equation we can see that the term that will dominate at large hidden dimensions is the h^2 term, since it’s the only one growing quadratically as we scale the hidden dimension.
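As a quick sanity check of this formula, here's a tiny snippet; the hyper-parameters are hypothetical, roughly 1B-scale values picked purely for illustration:

```python
# Rough parameter count from the formula above (hypothetical, ~1B-scale values)
h, v, L = 2048, 32_000, 16   # hidden dimension, vocabulary size, number of layers

N = h * v + L * (12 * h**2 + 13 * h) + 2 * h
print(f"{N / 1e9:.2f}B parameters")  # ~0.87B with these numbers
```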

Memory requirements for the parameters and gradients are simply the number of parameters multiplied by the number of bytes per parameter. In good old-fashioned full precision (FP32) training, both parameters and gradients require 4 bytes each, while the optimizer, if we use Adam, requires the momentum and variance to be stored, which adds another 4 bytes per parameter for each. In summary:

\begin{aligned} & m_{params} = 4 * N \\ & m_{grad} = 4 * N \\ & m_{opt} = (4+4) * N \end{aligned}

Now let’s have a look at how things change if we train with mixed precision. The default for mixed precision training nowadays is BF16, which requires 2 bytes per parameter and gradient, as well as an additional copy of the model weights and gradients in FP32, thus 12 bytes per parameter in total. In addition to the parameters and gradients, we need to store the optimizer states: for the Adam optimizer, this requires the momentum and the variance, usually stored in FP32 for numerical stability, each using 4 bytes.

Here’s the summary:

\begin{aligned} & m_{params} = 2 * N \\ & m_{grad} = 2 * N \\ & m_{params\_fp32} = 4 * N \\ & m_{opt} = (4+4) * N \end{aligned}

Interestingly, mixed precision itself doesn’t save overall memory, as it just distributes the memory differently across the three components, and in fact adds another 4 bytes over full precision training if we accumulate gradients in FP32. It’s still advantageous, though: having the model do the forward/backward pass in half precision allows us to (1) use optimized lower-precision operations on the GPU, which are faster, and (2) reduce the activation memory requirements during the forward pass.
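To make these numbers concrete, here is a minimal helper that reproduces the table below (using 1 GB = 10^9 bytes):

```python
def training_memory_gb(n_params, fp32_grad_accum=False):
    # BF16 params + BF16 grads + FP32 master weights + Adam momentum and variance (FP32)
    bytes_per_param = 2 + 2 + 4 + 4 + 4
    if fp32_grad_accum:
        bytes_per_param += 4  # extra FP32 gradient accumulation buffer
    return n_params * bytes_per_param / 1e9

print(training_memory_gb(7e9))                        # 112.0 -> matches the 7B row below
print(training_memory_gb(7e9, fp32_grad_accum=True))  # 140.0
```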

Let’s get a sense of how much memory we need in general for a model (with full and mixed precision giving the same overall value):

| Model parameters | FP32 or BF16 w/o FP32 grad acc | BF16 w/ FP32 grad acc |
|---|---|---|
| 1B | 16 GB | 20 GB |
| 7B | 112 GB | 140 GB |
| 70B | 1120 GB | 1400 GB |
| 405B | 6480 GB | 8100 GB |

As we can see, as soon as we reach 7B (!), weights and optimizer requirements already start to add up significantly and exceed the size of a typical GPU's memory, e.g. 80GB for an H100 GPU.

But for now, let’s stick with models which still fit in a single GPU and take a look at the other big contributor to our memory budget: the activation memory.

Activations memory

Activation memory is a bit more complex to compute than the weights, gradients and optimizer states, in part because it depends on the inputs of the model. If you’re unsure why we even need to store activations for the backward pass, this reference is a good quick refresher. After a careful inspection of how the backward pass is computed, we can estimate the total memory required for the activations in mixed precision and we arrive at the following equation:

m_{act} = L \cdot seq \cdot bs \cdot h \cdot \left(34 + \frac{5 \cdot n_{heads} \cdot seq}{h}\right)

Here L is the number of layers, seq the sequence length, bs the batch size in samples, h the hidden dimension of the model and n_{heads} the number of heads.

For the exact derivation of the numbers, you can follow the original NVIDIA paper on recomputation; it essentially requires you to do some accounting of all the sizes of intermediate activations between each operation in a transformer layer.
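If you want to play with the equation yourself, here is a small helper; the shapes below (h=4096, 32 layers, 32 heads, roughly an 8B Llama-style model) are just illustrative:

```python
def activation_memory_gb(L, seq, bs, h, n_heads):
    # Straight application of the equation above; result in GB (1e9 bytes)
    return L * seq * bs * h * (34 + 5 * n_heads * seq / h) / 1e9

# roughly Llama-3-8B-like shapes, bs=1, no recomputation
print(f"{activation_memory_gb(L=32, seq=4096, bs=1, h=4096, n_heads=32):.0f} GB")  # ~104 GB
```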

An interesting observation here is that the memory is not static for a given model: it scales linearly with the batch size and, because of the attention-score term, up to quadratically with the sequence length. This means the activation memory is the part which will blow up when we increase our batch size or train with longer sequences. We can use this equation to look at how memory usage changes for various sequence lengths, for example for Llama models (bs=1):

llama-memory-bars-no-recomp.png

This graph tells a striking story: for short sequences (or similar for small batch-sizes), activations are almost negligible, but starting at around 2-4k tokens they come to take a significant amount of memory while parameter, gradient and optimizer states usage (that we’ll discuss later) stays roughly independent of the sequence length and batch size.

For large input tokens (a.k.a large batch-sizes/sequences), activations become by far the largest memory burden.

Is there a way to tame this “activation explosion”? Good question, reader!

It’s time to explain our first technique, called activation recomputation, which will help us cap the activation memory footprint. It is an essential tool in today’s large model training toolbox.

Activation recomputation

The general idea behind activation recomputation – also called gradient checkpointing or rematerialization – is to discard some activations during the forward pass to save memory, and spend some extra compute to recompute them on the fly during the backward pass. Without recomputation, we store every hidden state between two learnable operations (e.g. feed-forward, layernorm, etc.), so that we can use them during the backward pass to compute gradients. When we use recomputation, we typically only store activations at a few key points along the model architecture, discard the rest, and recompute them on the fly during the backward pass from the nearest saved activations, basically performing a sub-part of the forward pass again to trade off memory for compute. It generally looks like this:

image.png

There are several strategies to select key activations to store: a full strategy checkpoints only the inputs of each transformer layer and recomputes everything else inside the layer during the backward pass (costing roughly one extra forward pass), while a selective strategy only discards the activations that are large but cheap to recompute, such as the attention scores and matrices.

Let’s see how drastically recomputation strategies can in practice reduce the memory footprint and how selective recomputation strikes a nice balance between memory saving and recomputation cost:

llama-8b-memory-bars--recomp.png

Most training frameworks these days use FlashAttention (which we’ll cover a bit later), which natively integrates activation recomputation in its optimization strategy by recomputing attention scores and matrices in the backward pass instead of storing them. Thus most people using Flash Attention are already making use of selective recomputation.

As you’ve now understood, activation recomputation increases the number of FLOPs slightly due to recomputation, while it significantly reduces memory access overhead.

This trade-off is particularly advantageous on hardware with small high-speed memory, like GPUs, as accessing memory is typically slower than performing computations. Despite the additional operations involved, the overall effect is thus often faster computation as well, in addition to the much lower memory footprint.
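In PyTorch this takes only a few lines. Here is a minimal sketch of full recomputation at the block level using torch.utils.checkpoint; the toy Block module is just a stand-in for a real transformer layer:

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    """Toy stand-in for a transformer layer."""
    def __init__(self, h):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(h, 4 * h), torch.nn.GELU(), torch.nn.Linear(4 * h, h)
        )

    def forward(self, x):
        return x + self.ff(x)

blocks = torch.nn.ModuleList([Block(1024) for _ in range(8)])
x = torch.randn(4, 128, 1024, requires_grad=True)

# Only the input of each block is kept; the activations inside the block
# are recomputed on the fly during the backward pass.
for block in blocks:
    x = checkpoint(block, x, use_reentrant=False)
x.sum().backward()
```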

Now that we’ve learned about recomputation, we can tame the activations memory usage as we saw in the above graphs!

However, activations still bear a linear dependence on the batch size, and all our profiles in the barplots above were using bs=1, so as we move to larger batch sizes it might become an issue again. Do not despair, as we have a second tool in our box – gradient accumulation to the rescue!

Gradient accumulation

Now that we’ve used activation recomputation to fit our model with a small batch size on a single GPU, we still need to reach our target batch size, let’s say 1M tokens (see our earlier discussion on optimal batch size). Gradient accumulation is a very straightforward method to avoid memory explosion when doing this.

With gradient accumulation we split our batch into micro-batches, do forward and backward passes repeatedly on each micro-batch, compute the gradients, and, as the name suggests, sum the gradients for each micro-batch before doing a final optimizer step. In practice, we perform the optimization step not on the sum but on the average of the gradients, so the result is independent of the number of gradient accumulation steps.

Let’s call the batch size for each forward pass the micro batch size (mbs). We’ll refer to the overall batch size between each optimizer step as the global batch size (gbs). If we do one optimizer step for each 8 forward/backward passes, the global batch size will be 8 times the micro batch size.

What we now call global batch size thus corresponds to what we’ve called up to now just batch size for simplicity (we now make our terms more precise to avoid ambiguity).

With gradient accumulation the global batch size can be simply computed as follows:

bs = gbs = mbs \times grad\_acc

Gradient accumulation allows us to effectively increase our batch size up to infinity (and beyond!) while the memory footprint stays constant. Gradient accumulation is also compatible with activation recomputation for further memory reduction. One drawback however, is that gradient accumulation requires multiple consecutive forward/backward passes per optimization step thereby increasing the compute overhead and slowing down training. No free lunch!

image.png

Gradient accumulation allows us to reduce the memory of activations, which grows linearly with batch size, by processing the batch as a sequence of smaller micro-batches. There is a small overhead caused by the additional forward and backward passes.
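In code, the core of gradient accumulation is only a few lines. This is a self-contained sketch with a toy model and synthetic data standing in for a real LLM and dataloader:

```python
import torch

model = torch.nn.Linear(512, 512)                       # toy stand-in for a real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
grad_acc_steps = 8                                      # gbs = mbs * grad_acc_steps (single GPU)

for step in range(64):
    micro_batch = torch.randn(2, 512)                   # mbs = 2, synthetic data
    loss = model(micro_batch).pow(2).mean() / grad_acc_steps  # scale so the sum becomes an average
    loss.backward()                                     # gradients accumulate in param.grad
    if (step + 1) % grad_acc_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```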

But if you’ve carefully followed, you probably noticed that the forward/backward passes for each micro-batch can actually be run in parallel. Forward/backward passes are independent from each other, with independent input samples being the only difference. Seems like it’s time to start extending our training to more than one GPU!

Let’s get a larger workstation 🖥️ with a couple of GPUs and start investigating our first scaling technique called data parallelism which is just a parallel version of gradient accumulation.

TODO: add profiling here or not?

Data Parallelism

The idea behind data parallelism (DP) is to replicate the model on several GPUs (we call the replicas “model instances”) and run forward and backward passes on different micro-batches of data in parallel on each GPU, hence the name Data Parallelism.

Using a different micro batch for each GPU means we’ll have different gradients in each GPU, so to keep the model instances in sync across different GPUs, the gradients from the model instances are averaged using an operation called “all-reduce”, which happens during the backward pass, before the optimizer step.

image.png

This involves our first “distributed communication” primitive: all-reduce which handles the synchronization and communication between GPU instances and nodes.

TODO: embed naive DP: https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L10-L60

TODO: embed bucket DP: https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/data_parallel/data_parallel.py#L62-L171

image.png

A naive DP implementation would just wait for the backward pass to finish so that we have all gradients, then trigger an all-reduce over all DP ranks to sync these gradients. But such a sequential pattern of computation followed by communication is A BIG NO! Because we don’t want our GPUs to stay idle while communication is happening.

Instead we should try to overlap communication and computation whenever possible so that they happen at the same time as much as possible.

Let’s see three optimizations that are done in practice for this!

First optimization: Overlap gradient synchronization with backward pass

The main drawback of the naive DDP approach we’ve just described is that after the backward pass (computation), we have to wait for gradient synchronization (communication) before updating the parameters. Could we overlap this communication with our computation? The answer is yes!

As shown in the figure above, the gradients (red boxes) for a layer can be gathered and summed even before the gradients from earlier layers (red boxes to the left) have been computed. For example, as soon as the backward pass of the last layer is complete (last box on the right), those gradients can already be gathered and summed while the backward computations continue for earlier layers, moving toward the left.

This can be achieved in pytorch by attaching an all-reduce hook function to each parameter. An all-reduce operation is triggered as soon as the gradient for that parameter is ready, while gradients for other parameters are still being computed. This approach overlaps most of the all-reduce operations with gradient calculations, thereby improving efficiency. Here's a simple function to attach a hook:

def register_backward_hook(self, hook):
    """
    Registers a backward hook for all parameters of the model that
    require gradients.
    """
    for p in self.module.parameters():
        if p.requires_grad is True:
            p.register_post_accumulate_grad_hook(hook)

image.png

Overlapping computation and communication reduces the time spent waiting for gradient synchronization across the entire model. Gradient synchronization can occur (at least partially) in parallel with backward pass, significantly speeding up data parallelism.

This is our first example of “overlapping computation and communication”, which we will discuss several times in this blog post and which is an essential technique for maximal scaling efficiency. Let's have a look at how we can further improve the DP efficiency!

Second optimization: Bucketing gradients

We can even go further with optimizing DP. For a given number of parameters to synchronize, GPU operations like collective communications are often more efficient when performing few calls on large tensors rather than many calls on smaller tensors. Therefore, instead of performing independent all-reduce for each gradient, we can group gradients into buckets and launch a single all-reduce for all the gradients within the same bucket. Think of it like packing items into boxes before shipping—it's more efficient to send a few big boxes than many small ones. By performing a single all-reduce operation for each bucket, we can significantly reduce communication overhead and speed up the communication operation.

image.png
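Here is a simplified sketch of the idea (PyTorch's DDP does this automatically and exposes the bucket size via its bucket_cap_mb argument); it assumes a process group has already been initialized, e.g. with torchrun:

```python
import torch
import torch.distributed as dist

def allreduce_bucketed(grads, bucket_cap_mb=25):
    """Group gradients into flat buckets and launch one all-reduce per bucket."""
    def flush(bucket):
        flat = torch.cat([g.view(-1) for g in bucket])   # pack the bucket into one big tensor
        dist.all_reduce(flat, op=dist.ReduceOp.SUM)      # a single collective call
        flat /= dist.get_world_size()                    # average over DP ranks
        offset = 0
        for g in bucket:                                 # copy the averaged values back
            g.copy_(flat[offset:offset + g.numel()].view_as(g))
            offset += g.numel()

    bucket, bucket_bytes = [], 0
    for g in grads:
        bucket.append(g)
        bucket_bytes += g.numel() * g.element_size()
        if bucket_bytes >= bucket_cap_mb * 1024 * 1024:
            flush(bucket)
            bucket, bucket_bytes = [], 0
    if bucket:
        flush(bucket)
```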

Third optimization: Interplay with gradient accumulation

As we’ve seen before, gradient accumulation works by performing multiple forward and backward passes before updating the parameters with optimizer.step(). When combining gradient accumulation with data parallelism, we should be careful when we want to synchronize gradients.

In a naive version, an all-reduce operation is automatically triggered after each backward pass during the accumulation, which is sub-optimal as a single reduce after the final step would have the same effect while reducing overhead.

In PyTorch, this is typically solved by using the model.no_sync() context manager, which disables gradient synchronization, around the backward passes which don’t need reduction.
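A minimal sketch of this pattern, assuming ddp_model is a torch.nn.parallel.DistributedDataParallel instance and the loss computation is just a placeholder:

```python
import contextlib

def accumulation_step(ddp_model, optimizer, micro_batches):
    for i, batch in enumerate(micro_batches):
        is_last = (i == len(micro_batches) - 1)
        # skip the gradient all-reduce on every backward except the last one
        ctx = contextlib.nullcontext() if is_last else ddp_model.no_sync()
        with ctx:
            loss = ddp_model(batch).pow(2).mean() / len(micro_batches)  # placeholder loss
            loss.backward()
    optimizer.step()      # gradients were synchronized only once, on the last backward pass
    optimizer.zero_grad()
```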

Revisit global batch size

Let’s update our batch size equation with our newly learned Data Parallelism and Gradient Accumulation parameters:

bs = gbs = mbs \times grad\_acc \times dp

Where grad\_acc is the number of gradient accumulation steps and DP is the number of parallel instances used for data parallelism.

Given a targeted global batch size, we can thus trade gradient accumulation steps for data-parallel processes to speed up training. In practice, people tend to maximize the degree of data parallelism (DP) over gradient accumulation as much as possible, since data parallelism is inherently parallel, unlike the sequential nature of gradient accumulation. Gradient accumulation is then added on top of data parallelism to achieve the target global batch size when scaling data parallelism alone is not sufficient, i.e. when you run out of GPUs.

Being able to distribute the training over different samples gives us a first dimension of parallelization, thus making this 1D parallelism (we’ll progressively cover 4 more dimensions).

Our journey up to now

Let’s quickly summarize what we’ve seen up to now and how to setup our first 1D parallel training with a draft recipe for an optimal data-parallel setup:

  1. We should first determine the best (global) batch size in tokens (GBST) either by consulting literature or running experiments measuring model convergence.
  2. We then select a sequence length for training, again by either consulting literature or running experiments. Generally, 2-8k tokens work reliably well for the evaluations we have today (we won’t dive in training recipes here but teams usually increase the sequence at the end of the training, adding some longer-context data samples in the mix to reach the longer context size of today).
  3. We now know the batch size (gbs). We can find the maximum local batch size (mbs) on a single GPU by increasing the local batch size until we run out of memory.
  4. Finally, we determine the number of available GPUs for our target DP. The ratio of GBS to DP times MBS gives us the remaining number of gradient accumulation steps needed for the desired GBS.

If the gradient accumulation ratio is lower than one, i.e. we have too many GPUs, a.k.a. we are GPU-rich 🤑 (!), we can either choose not to use all our GPUs, explore a larger global batch size, or test if a lower MBS will speed up training. In the latter case we’ll end up prioritizing throughput over individual GPU compute efficiency, using a smaller MBS than possible in order to speed up training.

Time to take a concrete example: Let’s say we want to train a recent model with a GBS of 4M tokens and a sequence length of 4k. This means our batch size will be 1024 samples (we pick powers of two). We observe that a single GPU can only fit MBS=2 in memory and we have 128 GPUs available for training. This means with 4 gradient accumulation steps we’ll achieve our goal of 1024 samples or 4M tokens per training step. Now what if we suddenly have 512 GPUs available? We can achieve the same GBS and thus identical training by keeping MBS=2 and setting gradient accumulation steps to 1 and achieve faster training!
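The arithmetic from this example is simple enough to wrap in a small helper (the function name is ours, just for illustration):

```python
def grad_acc_steps(gbs_samples, mbs, dp):
    assert gbs_samples % (mbs * dp) == 0, "pick a global batch size divisible by mbs * dp"
    return gbs_samples // (mbs * dp)

print(grad_acc_steps(gbs_samples=1024, mbs=2, dp=128))  # 4, as in the example above
print(grad_acc_steps(gbs_samples=1024, mbs=2, dp=512))  # 1
```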

While data parallelism cleverly overlaps the all-reduce gradient synchronization with backward computation to save time, this benefit starts to break down at large scales. As we add more and more GPUs (hundreds or thousands), the overhead of coordinating between them grows significantly. The end result? We get less and less efficient returns from each additional GPU we add to the system:

image.png

As expected, we can also see that the memory usage per GPU is not affected by adding more DP ranks for training.

We’ve explored data parallelism, our first (simple) strategy to scale training across more GPUs. It works like gradient accumulation but parallelizes the forward and backward passes on micro batches, thus increasing throughput!

The keen reader has probably already noted, however, that this assumes that we can fit at least one input sample forward pass (mbs=1) into our GPU memory. This is not always the case! As we can see, larger models don’t fit into a single GPU, even with activation recomputation activated:

image.png

Do we have other options for these larger models? Thankfully, we do have some solutions. They involve either moving some of these tensors to the CPU, or splitting the weights/gradients/optimizer-states tensors across GPU devices!

There are two main approaches to splitting: parallelism (tensor, context, or pipeline parallelism) and sharding (DeepSpeed ZeRO or PyTorch FSDP). Both approaches are somewhat orthogonal and can actually be combined! The sharding paradigm is closely related to DP, so we’ll have a look at it first by investigating the ZeRO method!

ZeRO (Zero Redundancy Optimizer)

In this section we will introduce DeepSpeed ZeRO (Zero Redundancy Optimizer), a memory optimization technology designed to reduce memory redundancies in LLM training.

While Data Parallelism is very efficient in scaling training, the naive replication of optimizer states, gradients, and parameters across each DP rank introduces a significant memory redundancy. ZeRO eliminates memory redundancy by partitioning the optimizer states, gradients, and parameters across the data parallel dimension, while still allowing computation with the full set of parameters. This sometimes requires more communications between DP ranks which may or may not be fully overlapped as we’ll see next!

This approach is organized into three possible optimization stages of ZeRO: ZeRO-1 partitions the optimizer states, ZeRO-2 additionally partitions the gradients, and ZeRO-3 additionally partitions the model parameters.

You might notice that the activations are missing from the list of things we can shard. Since each DP replica of the model receives a different micro-batch, the activations on each DP rank also differ, so they are not duplicated and thus can’t be sharded!

Let’s have a closer look how much we can save with the partitioning of each ZeRO stage!

Memory usage revisited

Let’s first recap the memory usage of optimizer states, gradients, and parameters during standard training. Let’s define the number of our model's parameters as \Psi (previously N, but here we use the original ZeRO notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is: 2\Psi for the BF16 parameters, 2\Psi for the BF16 gradients, and k\Psi = 12\Psi for the optimizer states (the FP32 copy of the parameters plus Adam's momentum and variance, 4 bytes each).

If we don’t accumulate gradients in fp32 this gives us a total memory consumption of 2\Psi + 2\Psi + 12\Psi, and if we accumulate it would be 2\Psi + 6\Psi + 12\Psi. Let’s focus for now on the case without fp32 gradient accumulation for simplicity but you can just add the additional bytes to the gradient term which are affected by ZeRO-2 and 3.

The idea of ZeRO is to shard these objects across the DP ranks, each node only storing a slice of the items which are reconstructed when and if needed, thereby dividing memory usage by the data parallel degree N_d:

image.png

Memory consumption of DP and three stages of Zero-DP. \Psi denotes number of parameters, k denotes the memory multiplier of optimizer states (k=12 for Adam), and N_d denotes DP degree.

Let’s explain this graph and its values by exploring how each ZeRO stage works. We’ll start with ZeRO-1.

ZeRO-1: Partitioning Optimizer States

In vanilla DP, all ranks gather the same gradients after the backward pass and simultaneously perform identical optimizer steps. This seems like a lot of duplicated work. Can we avoid it and reduce memory usage at the same time?

In ZeRO-1, the optimizer states are partitioned into N_d equal parts where N_d is the DP degree. This means that each model replica that’s distributed on each DP rank only keeps track of \frac{1}{N_d} of the optimizer states. During the optimization step only \frac{1}{N_d} of the float32 weights are updated, which we cast to get the corresponding \frac{1}{N_d} portion of the bfloat16 parameters.

However, for the forward pass we need all our bfloat16 parameters, so we need to add an additional all-gather (the second type of collective communication primitive we encounter!) after the optimizer step so that each model replica has the full set of updated weights.

This explains the memory formula of 2\Psi + 2\Psi + \frac{k\Psi}{N_d} that we saw on the above graph! Here’s a summary of the sequence of operations for a single training step:

  1. Forward pass with the full set of BF16 parameters on each replica (but different micro-batches)
  2. Backward pass with the full set of BF16 gradients on each replica (but different micro-batches)
  3. Perform a reduce-scatter on the gradients, so that each replica only keeps the \frac{1}{N_d} gradient slice it needs
  4. Each replica performs an optimizer step on its local \frac{1}{N_d} slice of the FP32 weights and optimizer states, and casts the updated slice to BF16
  5. Perform an all-gather of the updated BF16 parameter slices so that every replica has the full set of updated weights again

See the figure below for all the necessary steps in one forward/backward pass cycle:

image.png

So in practice, compared to vanilla DP, ZeRO-1 adds an all-gather over all parameters after the optimizer step, as we can see below:

image.png
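To make this sequence of operations concrete, here is a heavily simplified single-buffer sketch of a ZeRO-1-style update (not the DeepSpeed or picotron implementation). It assumes torch.distributed is already initialized (e.g. via torchrun), that the flat buffers divide evenly across ranks, and that fp32_shard and optimizer hold only this rank's slice:

```python
import torch
import torch.distributed as dist

def zero1_update(flat_params_bf16, flat_grads_bf16, fp32_shard, optimizer):
    world = dist.get_world_size()
    shard_numel = flat_grads_bf16.numel() // world

    # 1) reduce-scatter: each rank receives only the gradient slice matching its optimizer shard
    grad_shard = torch.empty(shard_numel, dtype=torch.bfloat16, device=flat_grads_bf16.device)
    dist.reduce_scatter_tensor(grad_shard, flat_grads_bf16, op=dist.ReduceOp.SUM)
    fp32_shard.grad = grad_shard.float() / world        # average over DP ranks

    # 2) local optimizer step on this rank's 1/N_d slice of the FP32 master weights
    optimizer.step()
    optimizer.zero_grad()

    # 3) all-gather: every rank gets the full set of updated BF16 parameters back
    dist.all_gather_into_tensor(flat_params_bf16, fp32_shard.detach().to(torch.bfloat16))
```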

If you've been following along, you'll recall from vanilla DP that we can overlap the all-reduce gradient communication with the backward pass computation. In ZeRO-1, we can also investigate how to efficiently overlap the newly added all-gather of BF16 parameters. There are two main strategies for this: we can either initiate the all-gather during the optimizer step, as soon as a subset of the parameters has been updated, or we can overlap the all-gather with the next forward pass, gathering each layer's parameters just before they are needed.

In ZeRO-1 the optimizer states have been partitioned, which means that each replica only updates \frac{1}{N_d} of the optimizer states. The keen reader must have noticed that there is no real need to have all gradients on all DP ranks in the first place since only a subset is needed for the optimization step. Meet ZeRO-2!

ZeRO-2: Adding Gradient Partitioning

The idea of ZeRO-2 is to shard not only the optimizer states but also the gradients. We actually only need the gradient shard corresponding to the optimizer state shard, so it makes sense to shard both the same way. [TODO: update] During the backward pass, instead of performing an all-reduce over the gradients, we only perform a reduce-scatter operation! Each rank only keeps the \frac{1}{N_d} gradient slice it needs in memory, thus saving more memory compared to ZeRO-1.

image.png

It’s easy to see now that sharding the gradients leads to 2\Psi + \frac{2\Psi+k\Psi}{N_d}, and as N_d is increased we can save up to 8x memory over the baseline. In terms of communication, the same process applies as for ZeRO-1, with the only difference that we communicate and release gradients on the fly. In total, ZeRO-2 is thus also equivalent to vanilla DP training w.r.t. communication.

In terms of communication, ZeRO-2 is similar to ZeRO-1: they both require a reduce-scatter for the gradients and an all-gather over all parameters.

image.png

Now that we’ve sharded gradients as well, are we done, or can we keep getting away with this? Well, sort of. We would like to reduce the memory of the parameters as well, and we’ve seen that we don’t need to wait for the entire all-gather to start the forward pass: we can already start the forward once we get the first layer... here comes ZeRO-3!

ZeRO-3: Adding Parameter Partitioning

For Stage 3 we extend the above approach of sharding optimizer states and gradients over DP replicas up to sharding the model’s parameters.

So how do we do a forward or backward pass in practice if all parts of the model are distributed? Quite simply we gather them on-demand when we need them. In the forward pass this looks as follows:

image.png

So as we perform the forward pass and sequentially go through the layers we retrieve the necessary parameters on demand and immediately flush them from memory when we don’t need them anymore. The backward pass works the same way just inverted in flow and we produce the gradient shards:

image.png

During the forward pass we do all-gather operations for the parameters when we need them, which costs a \Psi communication tax. Since we discard the parameters immediately after we’ve used them in the forward pass, we need one more all-gather during the backward pass as well, incurring another \Psi in communication tax. Finally we need the same reduce-scatter as in ZeRO-2 for the gradients, which also costs \Psi in communication, and we arrive at a total communication cost of 3\Psi, compared to 2\Psi for ZeRO-2.

Thankfully, although we added many more communication operations, prefetching helps us overlap them efficiently: we can all-gather the weights for Layer n+1 while doing the forward pass for Layer n, and similarly all-gather the weights for Layer n-1 while doing the backward pass for Layer n. Of course this overlap only holds true as long as we don’t scale DP too much (as a rule of thumb, DP shouldn’t exceed 512).

In terms of memory we can see that our equation has now reached its final form of \frac{2\Psi +2\Psi+k\Psi}{N_d}, which means we can drive memory usage down indefinitely if we can increase the DP degree, at least for the model-related memory. Notice how it doesn’t help with the intermediate activations; for that we can use activation checkpointing and gradient accumulation, as we’ve seen in earlier chapters.

Let’s summarize our journey into DP and ZeRO so far: we have seen that we can increase throughput of training significantly with DP, simply scaling training by adding more model replicas. With ZeRO we can train even models that would ordinarily not fit into a single GPU by sharding the parameters, gradients and optimizers states across DP, while incurring a small communications cost.

However, there is a limit here: DP only works if a layer of the model fits in a single GPU, and ZeRO can only partition the parameters, gradients, and optimizer states, but not the activation memory! Recall from the activation memory discussion that it scales with sequence length and batch size. Naturally we could just limit those, but in practice we don’t want to be limited by hardware to train only with short sequence lengths.

image.png

Now that we've efficiently used the DP axis to reduce memory through efficient communication patterns, let's explore a new, orthogonal axis of parallelism – Tensor Parallelism (TP). Unlike ZeRO-3, which relies on heavy parameter communication, TP manages to shard parameters, gradients, optimizer states AND activations across devices without requiring any model parameter movement between GPUs. What! How is this even possible?! Let's explore this seemingly magical approach together! 🙂

Tensor Parallelism

So we have sharded the model’s parameters, gradients and optimizer states with ZeRO, but we hit a limit once activation memory overtakes our memory budget. Welcome Tensor Parallelism (TP), a method which shards weights, gradients, and optimizer states as well as activations, without the need to gather them all prior to the computation. Seems like a dream! Let’s first have a look at how Tensor Parallelism works with simple matrix multiplications.

Tensor Parallelism leverages the mathematical properties of matrix multiplication A \times B. To understand how it works, let's examine two fundamental equations that make this parallelization possible:

\begin{aligned} &\text{1.} \quad A\cdot B = A \cdot \begin{bmatrix} B_1 & B_2 & \cdots \end{bmatrix} = \begin{bmatrix} AB_1 & AB_2 & \cdots \end{bmatrix} \\ &\text{2.} \quad A\cdot B =\begin{bmatrix} A_1 & A_2 & \cdots \end{bmatrix} \begin{bmatrix} B_1 \\ B_2 \\ \vdots \end{bmatrix} = \sum_{i=1}^n A_i B_i \end{aligned}

This means that we can compute the matrix product either by 1) multiplying each column of B individually, or 2) multiplying each row of B individually with the matching column chunk of A and summing the results. In a neural network, the matrix multiplication is more often represented in the following format: X \times W, where X represents the input or activation values and W represents the weight of e.g. a Linear layer.

In practice a small example of the operation looks like this:

image.png

Let’s see how we can parallelize this operation! In tensor parallelism, tensors will be split into N shards along a particular dimension and distributed across N GPUs. Matrices can be split either on the column part or the row part, leading to column and row parallelism. One thing we’ll see in the following is that choosing row or column sharding will require different communication primitives.

Our first option is to use column-wise sharding (also called column-linear): We'll copy the complete input matrices to each worker, requiring an operation called broadcast, and split the weight matrix into columns. The inputs are then multiplied with the partial weight matrices, and the results are finally combined using an all-gather operation.

image.png

The second option is called row-wise sharding (also called row-linear): As the attentive reader might guess, row-linear means that we split the weight matrix into chunks of rows. However, this also requires us to split the inputs, which needs a scatter operation rather than a broadcast as used in column-linear sharding. The results on each worker are already in the right shape but need to be summed for the final result, thus requiring an all-reduce operation in this scenario.

We see here our fourth distributed primitive: scatter!

image.png
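Before moving to a real transformer block, here's a toy single-process check of both sharding schemes, where the concatenation and summation stand in for the all-gather and all-reduce collectives:

```python
import torch

X = torch.randn(4, 8)   # input (batch, in_features)
W = torch.randn(8, 6)   # full weight of a linear layer

# Column-linear: split W along its output (column) dimension across 2 "GPUs"
Y_col = torch.cat([X @ w for w in W.chunk(2, dim=1)], dim=1)              # "all-gather" of partial outputs

# Row-linear: split W along its input (row) dimension; the input must be split to match
Y_row = sum(x @ w for x, w in zip(X.chunk(2, dim=1), W.chunk(2, dim=0)))  # "all-reduce" (sum) of partial outputs

assert torch.allclose(Y_col, X @ W, atol=1e-5)
assert torch.allclose(Y_row, X @ W, atol=1e-5)
```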

Tensor Parallelism in a Transformer Block

To come up with a strategy to follow, let’s move from a toy example to a real model building block. A Transformer model is made of two main building blocks: feedforward layers (MLP) and multi-head attention (MHA). We can apply tensor parallelism to both.

The feedforward part can be parallelized by having a “column linear” followed by a “row linear”, which amounts to a broadcast to copy the input and an all-reduce in the forward pass. Note that the broadcast isn’t needed in actual training, where we can make sure inputs are already synced across TP ranks.

image.png

Now that we’ve found the most efficient schema for the Feedforward part of the transformer, let’s take a look at the multi-head attention block (MHA).

We can generally follow a similar approach where Q, K, and V matrices are split in a column-parallel fashion, and the output projection is split along the row dimension. With multi-head attention, the column-parallel approach has a very natural interpretation: each worker computes the attention for an individual or a subset of heads. The same approach works as well for multi-query (MQA) or grouped query attention (GQA) where key and values are shared between queries.

It's also worth noting that the tensor parallelism degree should not exceed the number of Q/K/V heads, because we need intact heads per TP rank. And in case we’re using GQA, the TP degree should not exceed the number of K/V heads, otherwise additional communication is required to keep them in sync. For instance, LLaMA-3 8B has 8 key/value heads, so the tensor parallelism degree should be at most 8; if TP=16 for example, we need to duplicate each K/V head and make sure they stay in sync.

image.png

Finally, note that there is a tradeoff in terms of communication, as we’ve added several distributed communication primitives directly in the computation path of our model. Unlike ZeRO, where we could prefetch, it can be harder to make these communications fully overlap with computation.

Forward pass in Tensor Parallelism

Looking at the timeline of operations in tensor-parallel MLP (same applies for Attention), we can better understand the tradeoffs involved. In the forward of each decoder layer, we hit a synchronization point with the AllReduce operation that cannot be overlapped with computation. This exposed communication overhead is necessary to combine partial results across tensor-parallel ranks before the final LayerNorm can be applied.

Tensor parallelism does help reduce activation memory for the matrix multiplications since the intermediate activations are sharded across GPUs. However, we still need to gather the full activations for operations like LayerNorm, which means we're not getting the full memory benefits we could. Additionally, it introduces significant communication requirements that heavily depend on the network infrastructure. The inability to hide this particular AllReduce behind computation means it directly adds to the critical path of forward propagation.

Impact of Tensor Parallelism on model performance and batch size capacity: while increasing TP leads to reduced per-GPU throughput (left), it enables processing of larger batch sizes (right), illustrating the trade-off between computational efficiency and memory availability in distributed training.

In practice, the communication overhead of tensor parallelism becomes particularly noticeable as we scale beyond 8 GPUs. While tensor parallelism within a single node can leverage fast NVLink interconnects, going across nodes requires slower network connections. As shown in the throughput plot above, we observe significant drops when moving from TP=8 to TP=16, and an even steeper decline from TP=16 to TP=32. This illustrates how communication costs can dominate at higher degrees of parallelism.

However, tensor parallelism provides important benefits for memory usage by distributing model parameters, gradients, optimizer states and activations (to some extent) across GPUs. Let's examine this effect on a 70B parameter model:

image.png

As we can see, increasing tensor parallelism reduces the memory needed for model parameters, gradients and optimizer states on each GPU. While tensor parallelism does help reduce activation memory in attention and feedforward layers by sharding the matrix multiplications across GPUs, we don't get the full memory benefits we could. This is because operations like layer normalization and dropout still require gathering the full activations on each GPU, partially negating the memory savings. We can do better by finding ways to parallelize these remaining operations as well.

This raises an interesting question - could we extend tensor parallelism to these remaining operations as well? Indeed, it's possible to parallelize layer norm, dropout and other operations too, which we'll explore next.

Sequence Parallelism

In regions where we apply tensor parallelism (TP), like attention and feedforward layers, each GPU only needs to operate on a portion of the hidden dimension since the weights are sharded. However, operations like layer norm or dropout (which is not used much anymore in LLMs) require access to the full hidden dimension to compute correctly.

Rather than gathering the full hidden dimension on each GPU (which would defeat the memory benefits of TP), we can instead shard these operations along the sequence length dimension. This approach is called sequence parallelism (SP).

Sequence parallelism (SP) involves splitting the activations and computations for the parts of the model not handled by tensor parallelism (TP) such as Dropout and LayerNorm, but along the input sequence dimension rather than across hidden dimension. This is needed because these operations require access to the full hidden dimension to compute correctly. For example, LayerNorm needs the full hidden dimension to compute mean and variance:

\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

where \mu = \text{mean}(x) and \sigma^2 = \text{var}(x) are computed across hidden dimension h.

So even though these operations are computationally cheap, they still require significant activation memory since they need the complete hidden dimension. SP allows us to shard this memory burden across GPUs by splitting along the sequence dimension instead.

In practice we’ll go from the left diagram to the right:

in forward: f = no-op ; f* = all-reduce ; g = all-gather ; g* = reduce-scatter
in backward: f = all-reduce ; f* = no-op ; g = reduce-scatter ; g* = all-gather
SP region needs full hidden_dim

The diagram shows how we transition between tensor-parallel and sequence-parallel regions using different collective operations (labeled "f" and "g"). The key challenge is managing these transitions efficiently while keeping memory usage low and maintaining correctness.

In the forward pass: “f” is a no-op and “f*” is an all-reduce, while “g” is an all-gather and “g*” is a reduce-scatter.

In the backward pass, the operations are swapped: “f” becomes an all-reduce and “f*” a no-op, while “g” becomes a reduce-scatter and “g*” an all-gather.

The operations “f” and “f*” are called conjugate pairs because they complement each other: when one is a no-op in the forward pass, the other is an all-reduce in the backward pass, and vice versa.

For sequence parallelism (SP), we use different operations labeled "g" and "g*". Specifically, we avoid using all-reduce in the SP region since that would require gathering the full activations and increase our peak memory usage, defeating the purpose of SP.

image.png

So what is actually happening here? As a famous LLM would say, let’s take it step-by-step:

Initial LayerNorm (SP Region): the input activations are already sharded along the sequence dimension, so each GPU computes LayerNorm independently on its own sequence chunk.

First Transition (SP → TP): the “g” operation (an all-gather along the sequence dimension) restores the full sequence length before entering the TP region, where activations are sharded along the hidden dimension instead.

First Linear Layer (TP Region): the column-linear layer produces activations sharded along the hidden dimension, and element-wise operations like GELU can be applied independently on each shard.

Second Linear Layer (TP Region): the row-linear layer consumes the hidden-sharded activations and produces partial results that still need to be summed across TP ranks.

Final Transition (TP → SP): the “g*” operation (a reduce-scatter) performs this sum while simultaneously scattering the result along the sequence dimension, so we re-enter the SP region with sequence-sharded, full-hidden-dimension activations.

A key advantage of sequence parallelism is that it reduces the maximum activation size we need to store. In tensor parallelism alone, we had to store activations of shape (b,s,h) at various points. However, with sequence parallelism, the maximum activation size is reduced to \frac{b \cdot s \cdot h}{tp} since we always either split along the sequence or hidden dimensions.

It’s a bit difficult to keep track of all the parts that are sharded differently in TP and TP/SP – believe us, we find it hard to map as well, so we made this small table to summarize how the activation (a.k.a. hidden_states) shapes change across the hidden dimension h and sequence dimension s during a forward pass:

| Region | TP only | TP with SP |
|---|---|---|
| Enter TP (Column Linear) | h: sharded (weight_out is sharded); s: full | h: sharded (weight_out is sharded); s: all-gather to full |
| TP Region | h: sharded; s: full | h: sharded; s: full |
| Exit TP (Row Linear) | h: full (weight_out is full + all-reduce for correctness); s: full | h: full (weight_out is full + reduce-scatter for correctness); s: reduce-scatter to sharded |
| SP Region | h: full; s: full | h: full; s: sharded |
And for the embedding layer:

| Region | Vanilla TP | TP with SP |
|---|---|---|
| Embedding Layer (Row Linear, sharded on vocab) | h: full (weight_out is full + all-reduce for correctness); s: unchanged | h: full (weight_out is full + reduce-scatter for correctness); s: reduce-scatter to sharded |

You can find an example of implementation of both column and row linear TP in picotron: https://github.com/huggingface/picotron/blob/main/picotron/tensor_parallel/tensor_parallel.py

By using sequence parallelism, we can achieve even greater activation memory savings, allowing us to push our batch size and sequence length further than what would be possible with tensor parallelism alone. Let's see what that means for our previous 70B model example:

image.png

Does that mean that SP incurs more communication than TP? Well, yes and no. In the forward pass of vanilla TP we had two all-reduces per transformer block, and in SP we have two all-gathers and two reduce-scatters per transformer block. So SP does twice the number of communication operations as TP. But since an all-reduce operation can be broken down into an all-gather + reduce-scatter (see [TODO: Appendix link]), they’re actually equivalent in terms of communication volume. The same reasoning applies for the backward pass, as we just use the conjugate of each operation (no-op ↔ all-reduce and all-gather ↔ reduce-scatter).

If you’ve been paying close attention, you’ll notice that we’re talking about 4 comm ops in each layer (2 for Attention and 2 for MLP). This is what the MLP profiling looks like when using Tensor + Sequence Parallelism:

image.png

Besides the fact that TP requires communications in each layer, it also can’t easily be overlapped with compute, which makes throughput heavily dependent on the communication bandwidth. This is why TP is usually done only within a node (TP≤8).

As you might expect, this communication overhead becomes increasingly problematic as we scale up tensor parallelism. To illustrate this, let’s check throughput as we scale TP with SP for a 3B model:

image.png

Impact of combined Tensor and Sequence Parallelism (TP/SP) on a 3B model’s performance and memory utilization with 4096 seqlen: when scaling both TP and SP together, there's a trade-off between computational efficiency (left) and memory capacity (right). While higher parallelism degrees reduce per-GPU throughput, they enable processing of significantly larger batch sizes by reducing the activation memory.

Let’s summarize our observations:

We have seen how TP helps us shard activations across several GPUs by splitting the attention and feedforward operations along the hidden dimension and how SP is a natural complement for the remaining operations by splitting along the sequence dimension.

However, there are two limits to TP and SP: 1) if we scale the sequence length the activation memory will still blow up in the TP region and 2) if the model is too big to fit with TP=8 then we will see a massive slow-down due to the inter-node connectivity.

We can tackle problem 1) with Context parallelism and problem 2) with Pipeline parallelism. Let’s first have a look at Context parallelism!

Context Parallelism

With Tensor Parallelism and Sequence Parallelism, we can reduce the memory requirements per GPU significantly as both model weights and activations are distributed across GPUs. However, when training models on longer and longer sequences (e.g. when scaling to 128k or more tokens per sequence) we might still exceed the memory available on a single node, because inside the TP region we still have to process a full sequence length.

Even if we use full recomputation of the activations, which comes at a heavy compute overhead (30%), we still need to hold in memory some activations at the layer boundaries which scale linearly with sequence length:

image.png

Can we apply similar ideas to our sequence parallelism approach, but inside the modules where we already apply Tensor Parallelism, thereby also reducing the effect of sequence length? Yes, it’s time to talk about Context Parallelism, which you will find quite intuitive after all we’ve already covered.

Given that activations scale linearly with sequence length, can we find a way to consistently split activations along the sequence dimension throughout the entire model, while paying attention to the Attention blocks (haha.. funny jokes :D) that require access to the full sequence? This brings us to Context Parallelism, a natural extension of the concepts we've covered so far.

The idea of Context Parallelism is quite simple; just like Sequence Parallelism, we’ll split the input along the sequence dimension, but now we apply this splitting to the full model, instead of only to the sequence-parallel regions as we did previously with Tensor + Sequence Parallelism.

image.png

Splitting the sequence doesn't affect most modules like MLP and LayerNorm, where each token is processed independently. It also doesn’t require expensive communication like TP, as only the inputs are split and not the weight matrices. Just like data parallelism, after computing the gradients, an all-reduce operation is initiated to synchronize the gradients across the context parallelism group.

There is one important exception though, which is the attention module. In this module each token needs to access key/value pairs from all other tokens in the sequence or, in the case of causal attention, at least attend to each previous token.

Because Context Parallelism splits the inputs along the sequence dimension across GPUs, the attention module requires full communication between GPUs to exchange the necessary key/value data.

That sounds very expensive if we do it naively. Is there a way to do this efficiently and fast? Thankfully there is: a core technique to handle this communication of key/value pairs efficiently is called Ring Attention.

Discovering Ring Attention

In this implementation of attention, each GPU first initiates a communication operation to send its key/value pairs to other GPUs. While waiting for the other GPUs’ data, it computes the attention scores for the portion of the data it already has in memory. Ideally, the next key/value pair is received from another GPU before this computation finishes, allowing the GPU to start the next round of computation immediately after it finishes its first computation.

To illustrate this, let's suppose we have 4 GPUs and an input of 4 tokens. Initially, the input sequence is split evenly along the sequence dimension, so each GPU will have just one token along with its corresponding Q/K/V values. For example, Q1, K1, and V1 represent the query, key, and value of the first token, which are located on the 1st GPU. The attention calculation will take 4 time steps to complete. At each time step, each GPU follows these 3 stages:

  1. Send its “current keys and values” to the next machine (except during the last time step), in a non-blocking manner, so that the next step can start before this communication is finished.
  2. Locally compute the attention scores on the “current keys and values” it already has, which typically involves performing Softmax(\frac{QK^T}{\sqrt{d}}) * V.
  3. Wait to receive the keys and values from the previous GPU, then move back to step 1, with the “current keys and values” now being the key/values just received from the previous GPU.
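
Here is a schematic sketch of that loop on a single GPU; `attention` and `merge` are placeholders of ours for the local attention computation and the numerically stable (online-softmax) combination of partial results, which real implementations fuse with Flash Attention:

```python
import torch
import torch.distributed as dist

def ring_attention(q, k, v, rank, world_size, attention, merge):
    out = None
    for step in range(world_size):
        reqs = []
        if step < world_size - 1:
            # 1. send current K/V to the next rank and receive the next K/V
            #    from the previous rank, without blocking the local compute
            next_k, next_v = torch.empty_like(k), torch.empty_like(v)
            reqs = [dist.isend(k, dst=(rank + 1) % world_size),
                    dist.isend(v, dst=(rank + 1) % world_size),
                    dist.irecv(next_k, src=(rank - 1) % world_size),
                    dist.irecv(next_v, src=(rank - 1) % world_size)]
        # 2. compute attention of the local queries against the current K/V
        partial = attention(q, k, v)
        out = partial if out is None else merge(out, partial)
        if step < world_size - 1:
            # 3. wait for the K/V exchange to complete before the next round
            for req in reqs:
                req.wait()
            k, v = next_k, next_v
    return out
```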

The whole process with 4 GPUs is shown in the following animation:

image.png

With this animation, it’s also immediately clear why the authors chose to call this approach Ring Attention.

There is one big problem though, which is that a naive implementation of Ring Attention leads to a strong imbalance between GPUs, stemming from the shape of the causal attention matrix. Let’s take a real look at what is happening in the SoftMax computation by considering the attention score matrix with the causal attention mask:

image.png

The SoftMax is computed row-wise, which means a GPU can compute a row as soon as it has received all the tokens of that row. We see that GPU1 can compute its rows immediately, as it starts with tokens 1-4 and doesn’t need to receive any information from the other GPUs. However, GPU2 will need to wait for the second round to also receive tokens 1-4 and thus have all the values for tokens 1-8. Overall, GPU1 ends up performing much less work than all the other GPUs.

Let’s see if we can balance our computations better:

Zig-Zag Ring Attention – A Balanced Compute Implementation

We need a better way to distribute the input sequences. This can be achieved by assigning the tokens to the GPUs not purely sequentially, but by mixing the ordering a bit such that each GPU holds a good mix of early and late tokens. This approach is called Zig-Zag attention. In this new arrangement the attention mask shows an even distribution of computation: if you count the number of colored squares, you’ll see that the computation is now balanced across all GPUs.

image.png

At the same time we’ll also see that in order to complete all rows, each GPU will need information from all the other GPUs.
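
One common way to build this arrangement (used for instance in open-source ring-attention implementations) is to cut the sequence into 2 × (number of GPUs) chunks and give each GPU one early chunk and its mirror late chunk. The helper below is purely illustrative:

```python
def zigzag_token_ranges(seq_len: int, world_size: int, rank: int) -> list[range]:
    """Return the two token ranges held by `rank` in a zig-zag split."""
    chunk = seq_len // (2 * world_size)      # assume it divides evenly
    first = rank                             # an early chunk
    second = 2 * world_size - 1 - rank       # the matching late chunk
    return [range(first * chunk, (first + 1) * chunk),
            range(second * chunk, (second + 1) * chunk)]

# e.g. with seq_len=16 and 4 GPUs, rank 0 gets tokens 0-1 and 14-15,
# while rank 3 gets tokens 6-7 and 8-9.
```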

We have two general ways to overlap computation and communication: either we perform a general all-gather, regrouping all the KVs on each GPU at the same time (in a ZeRO-3 type of way), or we gather them one-by-one from each GPU as needed:

image.png

image.png

The key difference between these two implementations lies in their communication patterns and memory usage:

1. AllGather Implementation:

2. All-to-All (Ring) Implementation:

The All-to-All approach generally offers better memory efficiency at the cost of slightly more complex communication patterns, while the AllGather approach is simpler but requires more temporary memory during the attention computation.

We've now seen how we can split a model across one node with TP to tame large models and that we can use CP to tame the activation explosion with long sequences. However, we saw that TP doesn't scale well across nodes, so what can we do if the model weights don't easily fit on 1 node? Pipeline parallelism to the rescue!

Pipeline Parallelism

In the TP section we saw that if we try to scale Tensor Parallelism past the number of GPUs in a single node (typically 4 or 8), we hit a lower-bandwidth network called the “inter-node connection”, which can quite strongly impair our performance. We can see this clearly on e.g. the all-reduce operation when we perform it across several nodes:

image.png

Inter-node communication bandwidth measurements across different node counts, showing median (lines) and 5th-95th percentile ranges (shaded areas) for AllReduce, AllGather and ReduceScatter operations.

Sequence and context parallelism can help for long sequences, but they don’t help much if the root cause of our memory issues is not the sequence length but rather the size of the model itself. For large models (70B+), the size of the weights alone can already push past the limits of the 4-8 GPUs on a single node. We can solve this issue by summoning the fourth (and last) parallelism dimension: “pipeline parallelism”.

image.png

Pipeline Parallelism is conceptually very simple – we’ll simply spread the layers of our model across GPUs – but the devil lies in implementing it efficiently. Let’s dive into it!

Splitting layers on various nodes - All forward, all backward

So, let’s say we simply spread the layers across several devices, e.g. a first GPU takes the first few layers, a second GPU takes the second part of the model, and so on. The forward pass through our model now simply involves sequentially passing the batch of data along the model, thus successively using each compute device.

We get a first, direct advantage: the required interconnect bandwidth stays quite low, as we only send moderate-sized activations at a handful of locations along the model depth. This is a huge difference compared to, e.g., the communication in Tensor Parallelism, which happens several times within each layer.

But maybe you start feeling a glimpse of the troubles to come: “sequentially” and “successively”?!? This doesn’t sound very efficient in the world of parallel computation, especially after our discussion about computation and communication overlap.

Indeed, reader! The main challenge in pipeline parallelism will be how to efficiently circumvent the sequential nature of PP to keep our GPUs busy at all times and avoid having one GPU computing while the others are waiting. Here is how our GPU utilization looks when doing a naive and simple forward and backward pass through the model, where the numbers indicate the model layers:

image.png

An example of Pipeline parallelism for a model with 16 layers distributed across 4 GPUs. The numbers correspond to the layer IDs.

The remaining idle time is indicated in grey and usually called the “bubble”; the sight of it probably breaks your heart after we spent so much time optimizing throughput.

We can quantify how efficient a pipeline setup is by looking at how much time we lose because of the bubble. Let’s say t_f and t_b are the times for the forward and backward pass, respectively, as measured for one microbatch on one stage of the pipeline (a simple and common assumption is t_b \approx 2 \times t_f, which you can see on the above graph). If we could perfectly parallelize, the ideal total time would be t_{id}=t_f + t_b. However, we can count on the graph that due to the pipeline bubble there is an additional time of t_{pb}=(p-1)*(t_f+t_b) (where p is the degree of pipeline parallelism, i.e. the number of GPUs on the above graph), i.e. the time each GPU spends waiting while the other GPUs are computing.

We can compute the ratio of the additional bubble time over the ideal time:

r_{bubble} = \frac{(p-1)*(t_f+t_b)}{t_f+t_b} = p-1

As we add more stages the bubble time thus increases and the utilization drops.

Thankfully, various pipeline parallelism schemes have been designed to reduce the size of the bubble, which, as you can see in this naive example, can be very large.

Let’s take a first tool out of our toolbox and think about splitting our batch into smaller, bite-sized portions which can be processed in parallel (or almost), like we did before in data parallelism for instance. Now when the second GPU is busy processing micro-batch 1, the first GPU can already start processing micro-batch 2. Here is a schedule using 8 micro-batches:

image.png

The above schedule is called the all-forward-all-backward (AFAB) schedule, as we first do all the forward passes and only then all the backward passes. The advantage is that forward and backward steps are still generally sequential, so we preserve the general order of model training. This makes this option rather simple to implement.

You can find the full implementation of the AFAB pipeline in picotron: https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/pipeline_parallel/pipeline_parallel.py#L54-L83

Let’s estimate the bubble in this example. The difference with our first example is that the ideal time to process m microbatches is now t_{id} = m*(t_f+t_b):

r_{bubble} = \frac{(p-1)*(t_f+t_b)}{m*(t_f+t_b)} = \frac{p-1}{m}

As we can see, we can fight some inefficiencies of pipeline stages by adding more microbatches, reducing the size of the bubble by a factor of m.
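
As a rough illustration of this formula (ignoring communication and assuming a perfectly balanced pipeline), here is the bubble fraction for a p = 8 pipeline as we increase m:

```python
# Quick numerical check of r_bubble = (p - 1) / m for an 8-stage pipeline
p = 8
for m in (1, 4, 8, 32):
    r_bubble = (p - 1) / m
    utilization = 1 / (1 + r_bubble)   # fraction of time the GPUs actually compute
    print(f"m={m:>2}  bubble ratio={r_bubble:.2f}  utilization={utilization:.0%}")
# m= 1  bubble ratio=7.00  utilization=12%
# m= 4  bubble ratio=1.75  utilization=36%
# m= 8  bubble ratio=0.88  utilization=53%
# m=32  bubble ratio=0.22  utilization=82%
```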

However, just as annoying as the bubble is the memory required for storing all the activations. We need to keep all of the activations in memory until we reach the backward stage, which leads to a quick memory explosion in these implementations of PP. Can we do better and avoid this memory explosion?

Since the memory explosion is triggered by the activations we store for the backward pass, let’s see if we can start performing the backward pass while we are still performing other forward parts of the computation. This will allow us to drop some of the activations needed for the backward pass as soon as possible.

One-forward-one-backward and Llama 3.1 schemes

This schedule is called one-forward-one-backward (1F1B) as the middle/steady state involves alternatively performing one forward and one backward pass. The general idea is to start performing the backward pass as soon as possible. The schedule looks like this:

image.png

The bubble still has the same size, so our training efficiency is not significantly improved. However, we only need to store activations for p micro-batches instead of m, which significantly reduces the activation memory explosion we had in the AFAB schedule. As a consequence we can add more microbatches, which then will actually reduce the bubble.

image.png

A major complexity of this setup, visible on the above graph, is that forward and backward passes are not cleanly consecutive anymore but performed in parallel across devices. This means we will have to schedule the switch from forward to backward passes independently on each device, instead of in a simple and common central training loop as usual.

This is one of the reasons implementing Pipeline Parallelism usually requires rather extensive modifications to the training code as well as the modeling code.

Here is a simplified example of what the corresponding training loop can look like:
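
This is only a schematic sketch of a 1F1B schedule for one pipeline rank; `forward_step`/`backward_step` and the activation send/recv between neighbouring ranks are placeholder names of ours, not picotron’s API (the picotron link below has a complete, runnable version):

```python
def train_step_1f1b(num_microbatches, pp_rank, pp_size, forward_step, backward_step):
    # warmup: fill the pipeline with a few forward passes (more for early ranks)
    num_warmup = min(pp_size - pp_rank - 1, num_microbatches)
    num_steady = num_microbatches - num_warmup
    stored = []                               # activations waiting for their backward

    for _ in range(num_warmup):
        stored.append(forward_step())

    for _ in range(num_steady):               # steady state: one forward, one backward
        stored.append(forward_step())
        backward_step(stored.pop(0))

    for _ in range(num_warmup):               # cooldown: drain the remaining backwards
        backward_step(stored.pop(0))
```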

You can find the full implementation in picotron as well: https://github.com/huggingface/picotron/blob/0035cce0e04afd6192763b11efe50010d8ad0f71/picotron/pipeline_parallel/pipeline_parallel.py#L85-L145

So reordering the computations a bit helped a lot in reducing the memory pressure from activations. Could we get even better performance with more intricate schedules? Yes!

Interleaving stages

This schedule has let us improve memory usage, but not so much the size of the idle bubble. Can we also reduce the time spent in the bubble?

Well it turns out this is possible if we are willing to bring in a few additional communications. Time to talk about interleaved stages.

Up to now we’ve sliced our model naively along the model depth dimension, locating for instance layers 1-4 on the first GPU and layers 5-8 on the second GPU. But there are other ways we could think about slicing our layers, e.g. having odd layers 1, 3, 5, 7 on the first GPU and even layers 2, 4, 6, 8 on the second GPU.

This can be seen in general as a kind of “looping pipeline” where a micro-batch will move in circles from one GPU to the next as it goes through the forward pass through the model.

image.png

As a consequence we see additional communications happening, as the model goes through each GPU several times for the same computation that previously took just one pass. However, each forward and backward pass is divided by a factor of v, where v is the number of stages or model chunks per GPU, as we are able to better interleave forward and backward passes.

\begin{aligned} &t_{pb} = \frac{(p-1)*(t_f+t_b)}{v} \\ &r_{bubble} = \frac{1}{v}\frac{(p-1)*(t_f+t_b)}{m*(t_f+t_b)} = \frac{p-1}{v*m} \end{aligned}

So we can now decrease the bubble by adding microbatches and interleaved stages, but note that, quantitatively, the amount of communication also increases by v, so it’s a trade-off. In the following plot you can see several configurations for a PP setup with p=8, where the special case of m=1, v=1 corresponds to naive pipeline parallelism, the configurations with v=1 are AFAB or 1F1B setups, and v \neq 1 are interleaved configurations.

image.png

Scheduling also becomes more complex here, as we need to decide on a given GPU and at a given moment whether we prioritize earlier micro-batches, meaning that we close the forward and backward loops as fast as possible (so-called “depth-first”, i.e. prioritizing getting batches out of the model as fast as possible), or whether we prioritize first completing the forward passes of all micro-batches in the queue before moving on to backward passes (so-called “breadth-first”, i.e. prioritizing filling the pipeline as much as possible). This is explained in detail in the "Breadth-First Pipeline" paper.

You now have all the elements to understand the pipeline parallelism approach in Llama 3.1, which uses a one-forward-one-backward setup with interleaved stages and a priority setting tunable between depth-first and breadth-first.

image.png

However, we haven’t reached the end of possible pipeline schedules, and recently some methods have been proposed to reduce the bubble to virtually zero! Piqued your curiosity? Let’s have a look!

Zero Bubble and DualPipe

There are even more sophisticated ways to reduce the bubble and reach close to a “zero bubble” regime. The secret here is to split the operations involved at an even finer-grained level in order to interleave them in the most efficient way. For instance, the pipeline implementation approach in DeepSeek V3/R1, called DualPipe, reaches close to a zero-bubble regime.

Let’s very quickly see how this can work by briefly detailing the ZeroBubble work, which is a precursor to DualPipe. The base observation of ZeroBubble is that the backward pass through a matrix multiplication actually involves two separate operations: the backward for the inputs (B) and the backward for the weights (W):

image.png

While the output of B, the backward pass for the inputs, is necessary for performing the backward pass of the lower layers, the backward pass for the weights, W, is not needed by the rest of the backward pass and generally only needs to be performed before the optimizer step. This means W can be flexibly scheduled anywhere after the corresponding B of the same stage. This allows for strategic placement of W to fill the pipeline bubbles. The ZB-H2 schedule on the top right is an example of a (theoretical) schedule with zero bubble taking advantage of this fine-grained decomposition.
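
To see why the split is possible, consider a single linear layer y = x W: the two halves of its backward pass are two independent matrix multiplications, sketched below (a minimal illustration, not the ZeroBubble implementation):

```python
import torch

def backward_input(grad_output: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # B: dL/dx = dL/dy @ W^T  -> needed immediately by the previous pipeline stage
    return grad_output @ weight.t()

def backward_weight(grad_output: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    # W: dL/dW = x^T @ dL/dy  -> can be deferred and scheduled to fill bubbles
    return inputs.t() @ grad_output
```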

DeepSeek’s DualPipe, introduced with V3, proposes an extension of this decomposed approach with two streams propagating from both ends of the PP ranks, interleaved to minimize idle time on the GPUs even further, as displayed in the following scheduling graph:

image.png

The ZeroBubble and DualPipe schedules are a bit too complex for us to give code snippets here, but you should start to have a general idea of the concepts involved. In practice, optimizing these schedules requires careful measurement of the time for each operation, followed by a scheduling algorithm able to find the optimal allocation of time given the constraints. See for instance the ZeroBubble paper for a discussion of the heuristics and algorithms used to perform such scheduling.

This concludes our tour into the world of pipeline schedules and bubbles. Let's turn to the last parallelism method we can use to train large models efficiently: Expert parallelism.

Expert parallelism

One more thing: parallelism.

Mixture-of-expert models have gained some traction with models such as Mixtral or more recently DeepSeek-V3/R1! The basic idea is that instead of having a single feedforward module per layer we can have several and route tokens through different ones depending on their context:

image.png

Source: A Survey on Mixture of Experts

This design makes it very easy to add a new parallelism paradigm: Expert parallelism (EP). Since the feedforward layers are fully independent we can simply put each expert’s feedforward layer on a different worker. Compared to TP it’s much more lightweight, since we don’t need to split the matrix multiplication, we just need to route the hidden states of a token to the right expert. There are several tricks to make EP work in practice, closely tied to model design. For instance, DeepSeek-V3 enforces a constraint in the router, ensuring that each token is sent to at most M nodes (in their case, 4) to reduce communication overhead.
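
To illustrate the routing idea, here is a toy, single-process MoE sketch (top-1 routing, names are ours): with Expert Parallelism, the per-expert grouping of tokens below is replaced by all-to-all communication between the ranks hosting the experts, and real routers add top-k selection, capacity limits and load-balancing losses.

```python
import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    def __init__(self, hidden_dim: int, num_experts: int):
        super().__init__()
        self.router = nn.Linear(hidden_dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden_dim, 4 * hidden_dim), nn.GELU(),
                          nn.Linear(4 * hidden_dim, hidden_dim))
            for _ in range(num_experts))

    def forward(self, x):                       # x: [num_tokens, hidden_dim]
        logits = self.router(x)                 # [num_tokens, num_experts]
        expert_ids = logits.argmax(dim=-1)      # top-1 routing for brevity
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            mask = expert_ids == e              # tokens routed to expert e;
            out[mask] = expert(x[mask])         # with EP this grouping becomes an all-to-all
        return out
```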

While Expert parallelism has been around for a while, it is just now gaining new momentum as the MoE architecture becomes increasingly popular.

Congratulations, reader! With this brief overview of Expert parallelism you have now seen all 5 parallelism strategies to scale model training:

However, there is one aspect you are maybe curious about right now: how do all these parallelism strategies and ZeRO compare to each other? Let’s look at the similarities and interplay!

5D parallelism in a nutshell

Let’s start with Pipeline parallelism as ZeRO-3 and Pipeline parallelism have interesting similarities and differences.

Both methods are ways to partition the model weights over several GPUs and perform communication/computation along the model depth axis (for example, in ZeRO-3 we prefetch the next layer while computing the current one). In the following we say “a layer” to simplify what should in general be called “a set of layers” (as the basic sharding unit of the model). This means that in both cases full layers are computed on the device, as opposed to TP, where the layers themselves are sharded for the computation.

However, there are a few major differences between the two:

| ZeRO-3 | Pipeline parallel |
|---|---|
| each compute unit only stores a fraction of a layer | each compute unit stores a full layer |
| communication is used to transfer weights | communication is used to transfer activations |
| model-agnostic orchestration | model-agnostic orchestration |
| Complex implementation to handle model partitioning and communications | Complex implementation to handle efficient PP schedules |
| Prefers large mbs and seq\_len to hide comms | Prefers large grad\_acc to hide bubble |

Clearly ZeRO-3 and PP are distinctly different approaches to sharing the model layers and deciding to focus communication either on weights or on activations. While they can be combined, doing so requires increasing the global batch size significantly to amortize the communication costs, creating a tradeoff between global batch size, model size, network bandwidth, and training efficiency. If combined, ZeRO-3 should be configured to keep the weights in memory for each micro-batch in PP to minimize the communication overhead.

Note that ZeRO-1 and ZeRO-2 on the other hand are interesting to combine with Pipeline Parallelism as they focus on gradients and optimizer states instead of parameters and are thus complementary. For instance, DeepSeek-v3 used PP with ZeRO-1!

Tensor Parallelism (with Sequence Parallelism) is naturally complementary and interoperable with both Pipeline Parallelism and ZeRO-3, because it relies on the distributive property of matrix multiplication that allows weights and activations to be sharded and computed independently before being combined. However, TP has two important limitations: First, since its communication operations are part of the critical path of computation, it doesn't scale well beyond a certain point as communication overhead begins to dominate. Second, unlike ZeRO and PP which are model-agnostic, TP requires careful handling of activation sharding - sometimes along the hidden dimension (in the TP region) and sometimes along the sequence dimension (in the SP region) - making it more complex to implement correctly and requiring model-specific knowledge to ensure proper sharding patterns throughout.

image.png

When combining parallelism strategies, TP will typically be kept for high-speed intra-node communications while ZeRO-3 or PP can use parallelism groups spanning lower speed inter-node communications, since their communication patterns are more amenable to scaling. The main consideration is organizing the GPU groups efficiently for each parallelism dimension to maximize throughput and minimize communication overhead, while being mindful of TP's scaling limitations.
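
To make this concrete, here is a hypothetical sketch of how such groups could be organized with PyTorch’s device-mesh utility; the sizes and dimension names below are illustrative choices of ours, not prescriptions from the text:

```python
from torch.distributed.device_mesh import init_device_mesh

# 512 GPUs = 8 (dp) x 4 (pp) x 2 (cp) x 8 (tp), with 8 GPUs per node:
# TP is kept as the innermost dimension so that its heavy communication
# stays on fast intra-node links, while DP/PP span the slower inter-node network.
mesh = init_device_mesh("cuda", (8, 4, 2, 8), mesh_dim_names=("dp", "pp", "cp", "tp"))
tp_group = mesh["tp"].get_group()   # intra-node, bandwidth-hungry collectives
dp_group = mesh["dp"].get_group()   # inter-node, easier to overlap with compute
```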

Context Parallelism and Expert Parallelism also help us shard activations, and can be seen as complementary to TP – the former handles long sequences while the latter enables distributed Mixture of Experts training.

Context Parallelism (CP) specifically targets the challenge of training with very long sequences by sharding activations along the sequence dimension across GPUs. While most operations like MLPs and LayerNorm can process these sharded sequences independently, attention layers require communication since each token needs access to keys/values from the full sequence. This is handled efficiently through ring attention patterns that overlap computation and communication. CP is particularly valuable when scaling to extreme sequence lengths (128k+ tokens) where even with full activation recomputation the memory requirements for attention would be prohibitive on a single GPU.

image.png

Expert Parallelism (EP) specifically targets the challenge of training Mixture of Experts (MoE) models by sharding specialized "experts" across GPUs and dynamically routing tokens to relevant experts during computation. The key communication pattern in EP is the all-to-all operation needed to route tokens to their assigned experts and gather the results back. While this introduces some communication overhead, it enables scaling model capacity significantly since each token only needs to compute through a fraction of the total parameters. This partitioning of experts across GPUs becomes essential when working with models that have a large number of experts, like DeepSeek which uses 256 experts.

It's worth noting the scope of impact for these different parallelism strategies:

| Tensor + Sequence Parallel | Context Parallel | Expert Parallel |
|---|---|---|
| shards weights and activations along hidden/seq dim | shards activations along sequence dim | shards specialized expert weights and activations |
| communication for matrix multiply operations (column/row linears) | communication for attention key/values | communication for token routing to experts |
| model-specific implementation needed | model-agnostic except for attention | model-agnostic except for MoE layers |
| Prefers high-bandwidth intra-node communication | Prefers large sequence lengths | Requires MoEs |

Which leads us to this beautiful diagram summarizing everything we’ve seen:

image.png

And to have an idea of the memory benefits of each parallelism:

image.png

How to Find the Best Training Configuration

We’ve now covered all the parallelism techniques that are actually used to distribute and train larger models. There remains a general question: which ones should we choose, and how are they best combined? We touched on this a little at the end of the last section, but in this section we will walk through the decision process step by step.

First let's have a quick look at each parallel strategy and how it helps and at what cost it comes:

| Method | Memory savings | Parallel/sharding dimension | Disadvantage |
|---|---|---|---|
| DP | None (replicates everything) | Batch | Limited by max batch size |
| ZeRO-1 | Optimizer states | Batch | Params communication overhead |
| ZeRO-2 | Optimizer states and gradients | Batch | Params communication overhead |
| ZeRO-3 | Optimizer states, gradients, and model parameters | Batch and Model Params | Params communication overhead |
| PP | Model | Model layers | Idle bubble and complex schedules |
| TP/SP | Model and activations | Hidden dimension / Sequence length | Requires high-bandwidth communication |
| CP | Activations | Sequence length | Communication overhead in attention |
| EP | Expert parameters | Expert dimension | Requires MoE layers, routing overhead |

Clearly, there is no free lunch for any of these methods, but we can actually come up with a few rules that help find a good starting point. To find the definitive optimal setup you'll have to run a few experiments in any case.

Step 1: Fitting a Training Step in Memory

First, we need to figure out how we can fit a single model instance on GPUs. There are two general cases.

GPU-rich case 🤑 - when you have plenty of GPUs available:

Special considerations:

GPU-poor case 😭 - when running out of GPU resources:

Now that we have a single model instance training, we need to make sure we have the right batch size.

Step 2: Achieving Target Global Batch Size

Depending on how we set things up in step one in terms of micro-batch size and DP, our current global batch size might be too small or too big.
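
As a reminder, the global batch size is simply the product of the knobs we have at hand; a tiny helper (names are ours) makes the relation explicit:

```python
def global_batch_size(micro_batch_size: int, grad_accum_steps: int, dp_size: int) -> int:
    # tokens per optimizer step = global_batch_size * sequence_length
    return micro_batch_size * grad_accum_steps * dp_size

# e.g. mbs=2 with 4 gradient-accumulation steps and DP=64 gives 512 samples per step
assert global_batch_size(2, 4, 64) == 512
```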

To increase global batch size:

To decrease global batch size:

Ok, now we have the model running in the configuration we want, but is it the fastest way? Let's optimize throughput next.

Step 3: Optimizing Training Throughput

So we want to make sure the training is running as fast as possible so all our precious GPUs are well utilized at all times. As long as memory and communication aren't bottlenecks we can try the following:

We can roughly summarize the journey to the best configuration in the following diagram:

image.png

This concludes our very deep dive into the distribution methods of 5D parallelism. However, besides scaling our model efficiently across GPUs there is another way to improve model throughput and memory management.

Time to turn the lights off and activate CUDA mode!

Diving in the GPUs – fusing, threading, mixing

Up to now our discussion has been focused on the high-level organization of our model operations. We’ve moved around computations on various accelerators, taking into account general memory constraints and high-level scheduling of the compute units.

But this ignored all the optimizations we can do at a much lower level by carefully understanding how our model operations are scheduled and performed on each GPU.

This section will dive into much more detail of the GPU architecture, in particular NVIDIA’s, but the general ideas, as often, can be reused on similar accelerator units.

We’ll briefly explain how GPUs are organized before covering the Flash-Attention revolution, how to efficiently schedule workloads on GPUs, and finally explain how various precisions can be used efficiently on GPUs.

A primer on GPU

Generally, GPUs have a very hierarchical organization. In this primer we’ll keep the discussion at the conceptual level necessary for the rest of our presentation.

On the compute side, GPUs consist of an array of compute units called Streaming Multiprocessors (SM). Each SM contains and controls a set of streaming processors, also known as cores. For example, an Nvidia H100 GPU has 132 SMs with 128 cores per SM, resulting in a total of 16,896 cores (see docs for tensor cores for details), each capable of handling multiple threads simultaneously.
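
If you want to check these numbers on your own hardware, PyTorch exposes them via the device properties (assuming a CUDA-capable GPU and an installed PyTorch):

```python
import torch

props = torch.cuda.get_device_properties(0)
print(f"{props.name}: {props.multi_processor_count} SMs, "
      f"{props.total_memory / 1e9:.0f} GB of memory")
# e.g. on an H100 this reports 132 SMs
```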

image.png

How to improve performance with Kernels?

Memory Coalescing

Tiling

Thread Coarsening

Minimizing Control Divergence

Flash Attention 1-3

Fused Kernels

Mixed Precision Training

FP16 and BF16 training

FP8 pretraining

Conclusion

What you learned

What we learned

What’s next?

References

Landmark LLM Scaling Papers

Training Frameworks

Debugging

Distribution Techniques

CUDA Kernels

Hardware

Others

Appendix

Citation

For attribution in academic contexts, please cite this work as

XXX, et al., "The Ultra-Scale Playbook: Training LLMs on GPU Clusters", 2025.

BibTeX citation

@misc{TODO,
      title={The Ultra-Scale Playbook: Training LLMs on GPU Clusters},
      author={TODO},
      year={2025},
}