thomwolf HF staff commited on
Commit
cf3c07b
·
verified ·
1 Parent(s): 64d9f80
assets/images/sign-mantissa-exponent.svg ADDED
dist/assets/images/sign-mantissa-exponent.svg ADDED
dist/index.html CHANGED
@@ -345,7 +345,7 @@
345
  </p></div>
346
  </div>
347
 
348
- <p>These items are stored as tensors which come in different <em>shapes</em> and <em>precisions</em>. The <em>shapes</em> are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding as we’ll see later. <em>Precision</em> refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor.</p>
349
 
350
  <p>So how can I quickly determine memory usage from these variable? One simple way is to do this empirically and just measure it.</p>
351
 
@@ -815,7 +815,7 @@
815
 
816
  <h4>Memory usage revisited</h4>
817
 
818
- <p>You likely remember from <a target="_self" href="#memory_usage_in_transformers"> our previous section</a> the memory usage of optimizer states, gradients, and parameters during a standard training. Lets call our model's parameters count <d-math>\Psi</d-math> (previously N but here we use the original ZeRO paper notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:</p>
819
 
820
  <ul>
821
  <li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
@@ -2274,30 +2274,36 @@
2274
 
2275
  <p>In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.</p>
2276
 
2277
- <p>Non-blocking can be useful for overlapping communication and computation as we saw at several part along this blog post but can be extended to the more general idea of trying to avoid at all cost going back and forth between host and GPU kernel commands. This is beautifully illustrated by <a href="https://horace.io/brrr_intro.html">Horace He</a> in these diagrams:</p>
 
2278
  <div style="display: flex; gap: 20px; align-items: flex-start;">
2279
  <div style="width: 50%;">
2280
  <img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
 
2281
  <p>A sequence of kernels requiring back and forth between global memory and compute units</p>
2282
  </div>
 
2283
  <div style="width: 50%;">
2284
  <img alt="image.png" src="/assets/images/fused_kernels2.png" style="width: 100%;" />
2285
- <p>Instead of sending our triangle back to global memory just to read it back again, we instead just do all of our operations in one go.</p>
 
 
2286
  </div>
2287
  </div>
2288
 
2289
  <p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p>
2290
 
2291
 
2292
- <p>Fused kernel are especially efficient and simple to write for succession of point-like operations which are performed independently of each other on each input tokens. In this case, there is no point in bringing back computed values in Global Memory before moving them to SM memory and spinning up a new kernel. It’s much more efficient to keep all values local until the succession of computation has been performed.</p>
2293
 
2294
- <p>What are many places in a Transformer model were this can be advantageous, for instance when. a succession of point-wise operations is performed, e.g. in the computation involved in the Layer norms.</p>
2295
 
2296
  <p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>Flash Attention</em></strong></p>
2297
 
2298
  <h3>Flash Attention 1-3</h3>
2299
 
2300
- <p>Flash attention is a technique pioneered by <a href="https://tridao.me">Tri Dao</a> that optimizes the attention computations by writing custom CUDA kernels to make it much faster *and* more memory efficient. The idea behind Flash Attention is to make efficient use of the various memories of the GPU to avoid using too much the slowest global memory of the GPU (confusingly called the High Bandwidth Memory, HBM 🫠)</p>
 
2301
 
2302
  <p>A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
2303
 
@@ -2324,9 +2330,14 @@
2324
 
2325
  <p>Flash-Attention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.</p>
2326
 
2327
- <p>The techniques described so far in this section require specific modeling code changes and writing custom kernels for certain operations in order to speed up training. In this section we take a look at a range of methods that are agnostic to the modeling code and can be used for any model!</p>
 
 
 
2328
 
2329
  <h3>Mixed Precision Training</h3>
 
 
2330
 
2331
  <p>Mixed Precision Training, as the name suggests, involves mixing different precisions when training. The default numerical precision of PyTorch tensors is single-precision floating point format or also called FP32 or float32 which means that every number stored takes up 32 bits or 4 bytes. The available bits to represent a number are divided into 3 parts:</p>
2332
 
@@ -2336,6 +2347,10 @@
2336
  <li>Exponent: controls the magnitude of the number</li>
2337
  </ul>
2338
 
 
 
 
 
2339
  <p>The principle of floating point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. <d-math>- 5.734 \times 10^{7}</d-math>, where we first have the sign, followed by the mantissa an the exponent. As such we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default there is a range of floating point formats available in PyTorch:</p>
2340
 
2341
  <p></p>
@@ -2346,8 +2361,8 @@
2346
  <th><strong>Format</strong></th>
2347
  <th><strong>Total bits</strong></th>
2348
  <th><strong>Sign</strong></th>
2349
- <th><strong>Mantissa</strong></th>
2350
  <th><strong>Exponent</strong></th>
 
2351
  </tr>
2352
  </thead>
2353
  <tbody>
@@ -2355,36 +2370,36 @@
2355
  <td>float32</td>
2356
  <td>32</td>
2357
  <td>1</td>
2358
- <td>23</td>
2359
  <td>8</td>
 
2360
  </tr>
2361
  <tr>
2362
  <td>float16</td>
2363
  <td>16</td>
2364
  <td>1</td>
2365
- <td>10</td>
2366
  <td>5</td>
 
2367
  </tr>
2368
  <tr>
2369
  <td>bfloat16</td>
2370
  <td>16</td>
2371
  <td>1</td>
2372
- <td>7</td>
2373
  <td>8</td>
 
2374
  </tr>
2375
  <tr>
2376
  <td>float8 (e4m3)</td>
2377
  <td>8</td>
2378
  <td>1</td>
2379
- <td>3</td>
2380
  <td>4</td>
 
2381
  </tr>
2382
  <tr>
2383
  <td>float8 (e5m2)</td>
2384
  <td>8</td>
2385
  <td>1</td>
2386
- <td>2</td>
2387
  <td>5</td>
 
2388
  </tr>
2389
  </tbody>
2390
  </table>
@@ -2404,11 +2419,11 @@
2404
 
2405
  <p>We can see here that bfloat16 maintained the range of float32 over float16 but did this with the cost of sacrificing more precision. In case of float8 the situation is even more dire as e4m3 can represent 7 and e5m2 only 3 number on the interval 1-2.</p>
2406
 
2407
- <p>A common metric to measure a formats resolution is epsilon: the first representable number after 1.00. We can see that for the float32 format $10^{-4}$ is an upper bound (it’s actually <d-math>1.19^{-7}</d-math>). For float16 it is <d-math>\tilde 10^{-3}</d-math> and for bfloat 10x higher still.</p>
2408
 
2409
- <p>The idea of mixed precision training is to use some of these lower precisions formats while maintaining the performance of full precision training. It turns out we <strong>can’t</strong> totally abandon float32 and usually will need to maintain some parts in full precision.</p>
2410
 
2411
- <p>This is why lower precision training is usually called <strong><em>mixed precision</em></strong> training. </p>
2412
 
2413
  <p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further all the way down to 8 bits.</p>
2414
 
@@ -2421,10 +2436,10 @@
2421
  <ol>
2422
  <li><strong>FP32 copy of weights</strong>: There are two possible issues with float16 weights. During training some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to zero, if the updates are very small the difference in magnitude can cause the weights to underflow during the addition. Once the weights are zero they will remain 0 for the rest of training as there is no gradient signal coming through anymore.</li>
2423
  <li><strong>Loss scaling</strong>: We have a similar issue with the gradients as well as gradients tend to be much smaller than 1 and are thus at risk to underflow. A simple, yet effective, strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass and the scaling is not affecting training as we unscale before processing the gradients further (e.g. clipping) and the optimization step. </li>
2424
- <li><strong>Accumulation</strong>: Finally, when performing arithmetic operations in float16 such as in dot products, we can also face under or overflows. Does targeting certain types of arithmetic operations to accumulate the intermediate results in float32 during the operation and then casting the accumulated result back to fp16. For the same reason gradients are also accumulated in float32.</li>
2425
  </ol>
2426
 
2427
- <p>With these techniques, you get consistently stable training while benefitting from higher throughput due to the faster, lower precision operations. Naturally, as the curious reader you are and by now slightly addicted to maximizing the throughput, you ask the question: can we go further and faster? </p>
2428
 
2429
  <p>Maybe!</p>
2430
 
@@ -2436,6 +2451,7 @@
2436
 
2437
  <p>We know that instability increases as learning rates rise for a fixed model size<d-cite bibtex-key="wortsman2023smallscaleproxieslargescaletransformer"></d-cite>, making FP8 pretraining particularly tricky.</p>
2438
 
 
2439
  <iframe class="l-body-outset" id="plotFP8Loss" src="/assets/data/fp8/fp8_training_loss_curves.html" height="520" width="1000" scrolling="no" frameborder="0"></iframe>
2440
  <!-- Hynek uncomment this once it's added to -->
2441
  <!-- <div class="l-body-outset" id="fragment-fp8_training_loss_curves"></div> -->
@@ -2444,7 +2460,7 @@
2444
 
2445
  <p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
2446
 
2447
- <p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with smaller range, we need to normalize the range of values by computing the absolute maximum. DeepSeek-V3 also introduces a quantization scheme, where the ranges are normalized per tile: 1x128 for inputs/activations and 128x128 for weights and scale elements. This makes the normalization less susceptible to outliers. There is a number of additional tricks they deploy to also reduce the memory and communication footprint which you can follow in section 3.3. of the DeepSeek-V3 technical report<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. </p>
2448
 
2449
  <p>Here’s a summary of a few known approaches to FP8 training:</p>
2450
 
@@ -2525,11 +2541,13 @@
2525
  </tbody>
2526
  </table>
2527
 
2528
- <p>Overall, FP8 is still an experimental technique and methods are evolving, but will likely become the standard soon replacing bf16 mixed-precision. To follow a public implementations of this, please head to the nanotron’s implementation in <a href="https://github.com/huggingface/nanotron/pull/70">this PR</a>. </p>
2529
 
2530
- <p>In the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">have been announced </a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
 
 
2531
 
2532
- <p>We now arrived at the end of the distributed training journey. Let’s take a step back and conclude.</p>
2533
 
2534
  <h2>Conclusion</h2>
2535
 
 
345
  </p></div>
346
  </div>
347
 
348
+ <p>These items are stored as tensors which come in different <em>shapes</em> and <em>precisions</em>. The <em>shapes</em> are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding, as we’ll see later. <em>Precision</em> refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor. We will have a full discussion of the different precisions and their trade-offs in the <a target="_self" href="#mixed_precision_training">Mixed Precision Training</a> section; for now, let's just keep in mind that the memory requirements of these various formats differ, and that this impacts the memory usage of the items we need to store.</p>
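<p>As a quick illustration (a minimal sketch, assuming a recent PyTorch version with FP8 dtypes and a purely hypothetical parameter count), we can ask PyTorch directly how many bytes each format needs per value:</p>
<pre><code class="language-python">
import torch

# Minimal sketch (not the book's measurement code): bytes needed to store a tensor
# of N values in different precisions, using PyTorch's per-element sizes.
# N is a hypothetical parameter count chosen only for illustration.
N = 7_000_000_000

# torch.float8_e4m3fn is available in recent PyTorch versions
for dtype in [torch.float32, torch.bfloat16, torch.float8_e4m3fn]:
    bytes_per_value = torch.empty(0, dtype=dtype).element_size()  # 4, 2 and 1 bytes
    print(f"{dtype}: {N * bytes_per_value / 1e9:.1f} GB")
</code></pre>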
349
 
350
  <p>So how can I quickly determine memory usage from these variable? One simple way is to do this empirically and just measure it.</p>
351
 
 
815
 
816
  <h4>Memory usage revisited</h4>
817
 
818
+ <p>You likely remember from <a target="_self" href="#memory_usage_in_transformers">our previous section</a> the memory usage of optimizer states, gradients, and parameters during standard training. Let’s call our model's parameter count <d-math>\Psi</d-math> (previously N, but here we use the original ZeRO paper's notation). In mixed precision training (see the <a target="_self" href="#mixed_precision_training">Mixed Precision Training</a> section for more details) with the Adam optimizer, the memory usage for each item we need to store is:</p>
819
 
820
  <ul>
821
  <li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
 
2274
 
2275
  <p>In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.</p>
2276
 
2277
+ <p>Non-blocking execution can be useful for overlapping communication and computation, as we saw many times along our journey, but it can be extended to the more general idea of trying to avoid, at all costs, going back and forth between host and GPU kernel commands.</p>
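<p>To see what non-blocking scheduling looks like from the host side, here is a minimal sketch (assuming a CUDA device; not taken from this book's code): the CPU queues kernels and an asynchronous copy, and only blocks at the explicit synchronization point.</p>
<pre><code class="language-python">
import torch

# Kernel launches are queued asynchronously on the GPU; the CPU does not wait for them.
x = torch.randn(4096, 4096, device="cuda")

for _ in range(10):
    y = x @ x  # each matmul is enqueued; the host moves on immediately

# Pinned host memory lets the device-to-host copy also be asynchronous.
out = torch.empty(y.shape, dtype=y.dtype, pin_memory=True)
out.copy_(y, non_blocking=True)

torch.cuda.synchronize()  # the only point where the host actually waits for the GPU
</code></pre>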
2278
+ <p>This idea is beautifully illustrated by <a href="https://horace.io/brrr_intro.html">Horace He</a> in these diagrams:</p>
2279
  <div style="display: flex; gap: 20px; align-items: flex-start;">
2280
  <div style="width: 50%;">
2281
  <img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
2282
+ <div class="figure-legend">
2283
  <p>A sequence of kernels requiring back and forth between global memory and compute units</p>
2284
  </div>
2285
+ </div>
2286
  <div style="width: 50%;">
2287
  <img alt="image.png" src="/assets/images/fused_kernels2.png" style="width: 100%;" />
2288
+ <div class="figure-legend">
2289
+ <p>Instead of sending our triangle back to global memory just to read it back again, we instead just do all of our operations in one go.</p>
2290
+ </div>
2291
  </div>
2292
  </div>
2293
 
2294
  <p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p>
2295
 
2296
 
2297
+ <p>Fused kernels are especially efficient and simple to write for a succession of point-like operations that are performed independently of each other on each input token. In this case, there is no point in sending computed values back to global memory before moving them to SM memory and spinning up a new kernel. It’s much more efficient to keep all values local until the whole succession of computations has been performed.</p>
2298
 
2299
+ <p>There are many places in a Transformer model where this "fusing" approach can be applied: every time we have a succession of point-wise operations, e.g. in the computations involved in layer norms.</p>
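<p>As a concrete illustration (a sketch assuming PyTorch 2.x on a CUDA device, not the custom kernels discussed above), <code>torch.compile</code> can fuse such a chain of point-wise operations into a single generated kernel instead of launching one kernel per operation:</p>
<pre><code class="language-python">
import torch

def pointwise_chain(x: torch.Tensor) -> torch.Tensor:
    # A succession of element-wise ops: in eager mode PyTorch launches one kernel
    # per op, each reading from and writing back to global memory.
    return torch.relu(x * 1.5 + 1.0).sigmoid()

# torch.compile can generate a single fused kernel for the whole chain,
# keeping intermediate values on-chip instead of round-tripping through global memory.
fused_chain = torch.compile(pointwise_chain)

x = torch.randn(8192, 8192, device="cuda")
y = fused_chain(x)
</code></pre>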
2300
 
2301
  <p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>Flash Attention</em></strong></p>
2302
 
2303
  <h3>Flash Attention 1-3</h3>
2304
 
2305
+ <p>Flash Attention was introduced by <a href="https://tridao.me">Tri Dao</a> and optimizes the attention computation by writing custom CUDA kernels that make it much faster <em>and</em> more memory efficient. The idea behind Flash Attention is to make efficient use of the various memories of the GPU to avoid relying too much on the slowest one: the global memory of the GPU.</p>
2306
+ <aside>Note that the global memory of the GPU is confusingly called the "High Bandwidth Memory", HBM 🫠</aside>
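<p>To make the contrast concrete, here is a minimal sketch (with assumed tensor shapes, not this book's benchmark code) of a naive attention that materializes the full score matrices, next to PyTorch's <code>scaled_dot_product_attention</code>, which can dispatch to a fused Flash Attention kernel:</p>
<pre><code class="language-python">
import torch
import torch.nn.functional as F

# Illustrative shapes: batch=1, heads=16, seq_len=4096, head_dim=64
q = torch.randn(1, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Naive attention: the (seq_len, seq_len) score matrix S and the softmax P
# are materialized in global memory (HBM) before the final matmul.
S = q @ k.transpose(-2, -1) / (64 ** 0.5)
P = S.softmax(dim=-1)
out_naive = P @ v

# Fused path: PyTorch's built-in SDPA can dispatch to a Flash Attention kernel
# that never materializes S or P in HBM.
out_fused = F.scaled_dot_product_attention(q, k, v)
</code></pre>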
2307
 
2308
  <p>A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
2309
 
 
2330
 
2331
  <p>Flash-Attention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.</p>
2332
 
2333
+ <hr>
2334
+
2335
+ <p>The techniques described so far in this operation-fusion section have required us to change our modeling code and write custom kernels for certain operations in order to speed up training.</p>
2336
+ <p>In the final section of our low-level dive into the compute operations themselves, we will take a look at a family of methods that are agnostic to the modeling code, can be used for any model, and are so widely used that they have become an industry standard: <strong>Mixed Precision Training</strong>!</p>
2337
 
2338
  <h3>Mixed Precision Training</h3>
2339
+
2340
+ <p>In various sections along this book, we've talked about lower precision formats and their impact on the memory requirements for storing activations, parameters, and optimizer states. It's now time to dive deeper into the details of these formats and better understand their trade-offs, advantages, and limitations.</p>
2341
 
2342
  <p>Mixed Precision Training, as the name suggests, involves mixing different precisions when training. The default numerical precision of PyTorch tensors is single-precision floating point format or also called FP32 or float32 which means that every number stored takes up 32 bits or 4 bytes. The available bits to represent a number are divided into 3 parts:</p>
2343
 
 
2347
  <li>Exponent: controls the magnitude of the number</li>
2348
  </ul>
2349
 
2350
+ <p><img width="500px" alt="sign-mantissa-exponent.svg" src="/assets/images/sign-mantissa-exponent.svg" /></p>
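<p>To make this bit layout concrete, here is a small sketch in plain Python (illustrative only) that unpacks the sign, exponent, and mantissa bits of a float32 value:</p>
<pre><code class="language-python">
import struct

def float32_bits(x: float):
    # Pack as a big-endian IEEE-754 float32 and split the 32 bits into
    # sign (1 bit), exponent (8 bits) and mantissa (23 bits).
    bits = int.from_bytes(struct.pack(">f", x), "big")
    sign = bits >> 31
    exponent = (bits >> 23) & 0xFF
    mantissa = bits & 0x7FFFFF
    return sign, exponent, mantissa

print(float32_bits(1.0))       # (0, 127, 0): the exponent bias is 127, mantissa bits all zero
print(float32_bits(-5.734e7))  # sign=1, biased exponent=152 (i.e. 2**25), mantissa encodes the ~1.709 factor
</code></pre>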
2351
+
2352
+
2353
+
2354
  <p>The principle of floating point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. <d-math>- 5.734 \times 10^{7}</d-math>, where we first have the sign, followed by the mantissa an the exponent. As such we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default there is a range of floating point formats available in PyTorch:</p>
2355
 
2356
  <p></p>
 
2361
  <th><strong>Format</strong></th>
2362
  <th><strong>Total bits</strong></th>
2363
  <th><strong>Sign</strong></th>
 
2364
  <th><strong>Exponent</strong></th>
2365
+ <th><strong>Mantissa</strong></th>
2366
  </tr>
2367
  </thead>
2368
  <tbody>
 
2370
  <td>float32</td>
2371
  <td>32</td>
2372
  <td>1</td>
 
2373
  <td>8</td>
2374
+ <td>23</td>
2375
  </tr>
2376
  <tr>
2377
  <td>float16</td>
2378
  <td>16</td>
2379
  <td>1</td>
 
2380
  <td>5</td>
2381
+ <td>10</td>
2382
  </tr>
2383
  <tr>
2384
  <td>bfloat16</td>
2385
  <td>16</td>
2386
  <td>1</td>
 
2387
  <td>8</td>
2388
+ <td>7</td>
2389
  </tr>
2390
  <tr>
2391
  <td>float8 (e4m3)</td>
2392
  <td>8</td>
2393
  <td>1</td>
 
2394
  <td>4</td>
2395
+ <td>3</td>
2396
  </tr>
2397
  <tr>
2398
  <td>float8 (e5m2)</td>
2399
  <td>8</td>
2400
  <td>1</td>
 
2401
  <td>5</td>
2402
+ <td>2</td>
2403
  </tr>
2404
  </tbody>
2405
  </table>
 
2419
 
2420
  <p>We can see here that bfloat16 maintained the range of float32 over float16 but did this with the cost of sacrificing more precision. In case of float8 the situation is even more dire as e4m3 can represent 7 and e5m2 only 3 number on the interval 1-2.</p>
2421
 
2422
+ <p>A common metric to measure a format's resolution is epsilon: the gap between <d-math>1.00</d-math> and the first representable number after it. We can see that for the float32 format <d-math>10^{-4}</d-math> is an upper bound (it’s actually <d-math>1.19 \times 10^{-7}</d-math>). For float16 it is <d-math>\sim 10^{-3}</d-math> and for bfloat16 it is 10x higher still.</p>
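<p>These resolution claims can be checked directly with <code>torch.finfo</code> (a minimal sketch; the float8 dtypes assume a recent PyTorch version):</p>
<pre><code class="language-python">
import torch

# Epsilon (gap after 1.0) and largest representable value for each format.
for dtype in [torch.float32, torch.float16, torch.bfloat16,
              torch.float8_e4m3fn, torch.float8_e5m2]:
    info = torch.finfo(dtype)
    print(f"{str(dtype):22} eps={info.eps}  max={info.max}")

# float32 gives eps ~1.19e-07, float16 ~9.8e-04 and bfloat16 ~7.8e-03,
# matching the orders of magnitude quoted above.
</code></pre>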
2423
 
2424
+ <p>The idea of mixed precision training is to use some of these lower precision formats while maintaining the performance of full precision training. </p>
2425
 
2426
+ <p>It turns out we <strong>can’t</strong> totally abandon float32 and usually will need to maintain some parts in full precision. This is why lower precision training is usually called <strong><em>mixed precision</em></strong> training. </p>
2427
 
2428
  <p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further all the way down to 8 bits.</p>
2429
 
 
2436
  <ol>
2437
  <li><strong>FP32 copy of weights</strong>: There are two possible issues with float16 weights. During training some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to zero, if the updates are very small the difference in magnitude can cause the weights to underflow during the addition. Once the weights are zero they will remain 0 for the rest of training as there is no gradient signal coming through anymore.</li>
2438
  <li><strong>Loss scaling</strong>: We have a similar issue with the gradients as well as gradients tend to be much smaller than 1 and are thus at risk to underflow. A simple, yet effective, strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass and the scaling is not affecting training as we unscale before processing the gradients further (e.g. clipping) and the optimization step. </li>
2439
+ <li><strong>Accumulation</strong>: Finally, when performing certain arithmetic operations in 16-bit precision, such as averages or summations, we can also face under- or overflows. A solution is then to accumulate intermediate results in float32 during the operation and only cast the final result back to 16-bit precision (a PyTorch sketch combining these three ingredients follows this list).</li>
2440
  </ol>
2441
 
2442
+ <p>With these techniques, we can get stable training while benefiting from higher throughput thanks to the faster, lower precision arithmetic operations. Naturally, as a curious reader by now slightly addicted to maximizing throughput, you may ask: can we go further and faster than 16-bit precision? </p>
2443
 
2444
  <p>Maybe!</p>
2445
 
 
2451
 
2452
  <p>We know that instability increases as learning rates rise for a fixed model size<d-cite bibtex-key="wortsman2023smallscaleproxieslargescaletransformer"></d-cite>, making FP8 pretraining particularly tricky.</p>
2453
 
2454
+ <p>Here is an example of a typical divergent loss curve for FP8 training:</p>
2455
  <iframe class="l-body-outset" id="plotFP8Loss" src="/assets/data/fp8/fp8_training_loss_curves.html" height="520" width="1000" scrolling="no" frameborder="0"></iframe>
2456
  <!-- Hynek uncomment this once it's added to -->
2457
  <!-- <div class="l-body-outset" id="fragment-fp8_training_loss_curves"></div> -->
 
2460
 
2461
  <p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
2462
 
2463
+ <p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with a smaller range, we need to normalize the range of activation values, for instance by computing their absolute maximum. DeepSeek-V3 further introduced a specific quantization scheme where the ranges are normalized per tile: 1x128 for inputs/activations and 128x128 for weights and scale elements. This makes the normalization less strongly impacted by outlier values in the activations. There are a number of additional tricks they proposed to further reduce the memory and communication footprint, which you can follow in section 3.3 of the DeepSeek-V3 technical report<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. </p>
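<p>To make the per-tile idea concrete, here is a rough sketch of 1x128 absmax scaling for activations (illustrative shapes and helper names; this is not DeepSeek-V3's actual kernel, and the float8 cast assumes a recent PyTorch version):</p>
<pre><code class="language-python">
import torch

def quantize_activations_1x128(x: torch.Tensor, fp8_max: float = 448.0):
    # x: (tokens, hidden) with hidden divisible by 128; one scale per 1x128 tile.
    tokens, hidden = x.shape
    tiles = x.view(tokens, hidden // 128, 128)
    # Per-tile absolute maximum, mapped onto the e4m3 representable range (max ~448).
    scales = tiles.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    x_fp8 = (tiles / scales).to(torch.float8_e4m3fn)  # float8 dtype in recent PyTorch
    return x_fp8, scales  # scales stay in higher precision for dequantization

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
x_fp8, scales = quantize_activations_1x128(x)
</code></pre>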
2464
 
2465
  <p>Here’s a summary of a few known approaches to FP8 training:</p>
2466
 
 
2541
  </tbody>
2542
  </table>
2543
 
2544
+ <p>Overall, FP8 remains, as of early 2025, an experimental technique and methods are still evolving. Given its obvious benefits, it will likely become the standard and soon replace bf16 mixed precision. To follow an open-source implementation of FP8 training techniques, head to nanotron’s implementation in <a href="https://github.com/huggingface/nanotron/pull/70">this PR</a>. </p>
2545
 
2546
+ <p>Projecting further into the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">has been announced</a> to support FP4 training, which would further speed up training but will without a doubt also introduce new training stability challenges.</p>
2547
+
2548
+ <hr>
2549
 
2550
+ <p>This last section concludes our long journey through the land of fast and large-scale model training on tens to thousands of GPUs. Time to slowly bring our GPU cluster to rest and take a step back to reflect on all we've learned along the way.</p>
2551
 
2552
  <h2>Conclusion</h2>
2553
 
src/index.html CHANGED
@@ -345,7 +345,7 @@
345
  </p></div>
346
  </div>
347
 
348
- <p>These items are stored as tensors which come in different <em>shapes</em> and <em>precisions</em>. The <em>shapes</em> are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding as we’ll see later. <em>Precision</em> refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor.</p>
349
 
350
  <p>So how can I quickly determine memory usage from these variable? One simple way is to do this empirically and just measure it.</p>
351
 
@@ -815,7 +815,7 @@
815
 
816
  <h4>Memory usage revisited</h4>
817
 
818
- <p>You likely remember from <a target="_self" href="#memory_usage_in_transformers"> our previous section</a> the memory usage of optimizer states, gradients, and parameters during a standard training. Lets call our model's parameters count <d-math>\Psi</d-math> (previously N but here we use the original ZeRO paper notation). In mixed-precision training with the Adam optimizer, the memory usage for each item we need to store is:</p>
819
 
820
  <ul>
821
  <li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
@@ -2274,30 +2274,36 @@
2274
 
2275
  <p>In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.</p>
2276
 
2277
- <p>Non-blocking can be useful for overlapping communication and computation as we saw at several part along this blog post but can be extended to the more general idea of trying to avoid at all cost going back and forth between host and GPU kernel commands. This is beautifully illustrated by <a href="https://horace.io/brrr_intro.html">Horace He</a> in these diagrams:</p>
 
2278
  <div style="display: flex; gap: 20px; align-items: flex-start;">
2279
  <div style="width: 50%;">
2280
  <img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
 
2281
  <p>A sequence of kernels requiring back and forth between global memory and compute units</p>
2282
  </div>
 
2283
  <div style="width: 50%;">
2284
  <img alt="image.png" src="/assets/images/fused_kernels2.png" style="width: 100%;" />
2285
- <p>Instead of sending our triangle back to global memory just to read it back again, we instead just do all of our operations in one go.</p>
 
 
2286
  </div>
2287
  </div>
2288
 
2289
  <p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p>
2290
 
2291
 
2292
- <p>Fused kernel are especially efficient and simple to write for succession of point-like operations which are performed independently of each other on each input tokens. In this case, there is no point in bringing back computed values in Global Memory before moving them to SM memory and spinning up a new kernel. It’s much more efficient to keep all values local until the succession of computation has been performed.</p>
2293
 
2294
- <p>What are many places in a Transformer model were this can be advantageous, for instance when. a succession of point-wise operations is performed, e.g. in the computation involved in the Layer norms.</p>
2295
 
2296
  <p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>Flash Attention</em></strong></p>
2297
 
2298
  <h3>Flash Attention 1-3</h3>
2299
 
2300
- <p>Flash attention is a technique pioneered by <a href="https://tridao.me">Tri Dao</a> that optimizes the attention computations by writing custom CUDA kernels to make it much faster *and* more memory efficient. The idea behind Flash Attention is to make efficient use of the various memories of the GPU to avoid using too much the slowest global memory of the GPU (confusingly called the High Bandwidth Memory, HBM 🫠)</p>
 
2301
 
2302
  <p>A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
2303
 
@@ -2324,9 +2330,14 @@
2324
 
2325
  <p>Flash-Attention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.</p>
2326
 
2327
- <p>The techniques described so far in this section require specific modeling code changes and writing custom kernels for certain operations in order to speed up training. In this section we take a look at a range of methods that are agnostic to the modeling code and can be used for any model!</p>
 
 
 
2328
 
2329
  <h3>Mixed Precision Training</h3>
 
 
2330
 
2331
  <p>Mixed Precision Training, as the name suggests, involves mixing different precisions when training. The default numerical precision of PyTorch tensors is single-precision floating point format or also called FP32 or float32 which means that every number stored takes up 32 bits or 4 bytes. The available bits to represent a number are divided into 3 parts:</p>
2332
 
@@ -2336,6 +2347,10 @@
2336
  <li>Exponent: controls the magnitude of the number</li>
2337
  </ul>
2338
 
 
 
 
 
2339
  <p>The principle of floating point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. <d-math>- 5.734 \times 10^{7}</d-math>, where we first have the sign, followed by the mantissa an the exponent. As such we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default there is a range of floating point formats available in PyTorch:</p>
2340
 
2341
  <p></p>
@@ -2346,8 +2361,8 @@
2346
  <th><strong>Format</strong></th>
2347
  <th><strong>Total bits</strong></th>
2348
  <th><strong>Sign</strong></th>
2349
- <th><strong>Mantissa</strong></th>
2350
  <th><strong>Exponent</strong></th>
 
2351
  </tr>
2352
  </thead>
2353
  <tbody>
@@ -2355,36 +2370,36 @@
2355
  <td>float32</td>
2356
  <td>32</td>
2357
  <td>1</td>
2358
- <td>23</td>
2359
  <td>8</td>
 
2360
  </tr>
2361
  <tr>
2362
  <td>float16</td>
2363
  <td>16</td>
2364
  <td>1</td>
2365
- <td>10</td>
2366
  <td>5</td>
 
2367
  </tr>
2368
  <tr>
2369
  <td>bfloat16</td>
2370
  <td>16</td>
2371
  <td>1</td>
2372
- <td>7</td>
2373
  <td>8</td>
 
2374
  </tr>
2375
  <tr>
2376
  <td>float8 (e4m3)</td>
2377
  <td>8</td>
2378
  <td>1</td>
2379
- <td>3</td>
2380
  <td>4</td>
 
2381
  </tr>
2382
  <tr>
2383
  <td>float8 (e5m2)</td>
2384
  <td>8</td>
2385
  <td>1</td>
2386
- <td>2</td>
2387
  <td>5</td>
 
2388
  </tr>
2389
  </tbody>
2390
  </table>
@@ -2404,11 +2419,11 @@
2404
 
2405
  <p>We can see here that bfloat16 maintained the range of float32 over float16 but did this with the cost of sacrificing more precision. In case of float8 the situation is even more dire as e4m3 can represent 7 and e5m2 only 3 number on the interval 1-2.</p>
2406
 
2407
- <p>A common metric to measure a formats resolution is epsilon: the first representable number after 1.00. We can see that for the float32 format $10^{-4}$ is an upper bound (it’s actually <d-math>1.19^{-7}</d-math>). For float16 it is <d-math>\tilde 10^{-3}</d-math> and for bfloat 10x higher still.</p>
2408
 
2409
- <p>The idea of mixed precision training is to use some of these lower precisions formats while maintaining the performance of full precision training. It turns out we <strong>can’t</strong> totally abandon float32 and usually will need to maintain some parts in full precision.</p>
2410
 
2411
- <p>This is why lower precision training is usually called <strong><em>mixed precision</em></strong> training. </p>
2412
 
2413
  <p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further all the way down to 8 bits.</p>
2414
 
@@ -2421,10 +2436,10 @@
2421
  <ol>
2422
  <li><strong>FP32 copy of weights</strong>: There are two possible issues with float16 weights. During training some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to zero, if the updates are very small the difference in magnitude can cause the weights to underflow during the addition. Once the weights are zero they will remain 0 for the rest of training as there is no gradient signal coming through anymore.</li>
2423
  <li><strong>Loss scaling</strong>: We have a similar issue with the gradients as well as gradients tend to be much smaller than 1 and are thus at risk to underflow. A simple, yet effective, strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass and the scaling is not affecting training as we unscale before processing the gradients further (e.g. clipping) and the optimization step. </li>
2424
- <li><strong>Accumulation</strong>: Finally, when performing arithmetic operations in float16 such as in dot products, we can also face under or overflows. Does targeting certain types of arithmetic operations to accumulate the intermediate results in float32 during the operation and then casting the accumulated result back to fp16. For the same reason gradients are also accumulated in float32.</li>
2425
  </ol>
2426
 
2427
- <p>With these techniques, you get consistently stable training while benefitting from higher throughput due to the faster, lower precision operations. Naturally, as the curious reader you are and by now slightly addicted to maximizing the throughput, you ask the question: can we go further and faster? </p>
2428
 
2429
  <p>Maybe!</p>
2430
 
@@ -2436,6 +2451,7 @@
2436
 
2437
  <p>We know that instability increases as learning rates rise for a fixed model size<d-cite bibtex-key="wortsman2023smallscaleproxieslargescaletransformer"></d-cite>, making FP8 pretraining particularly tricky.</p>
2438
 
 
2439
  <iframe class="l-body-outset" id="plotFP8Loss" src="/assets/data/fp8/fp8_training_loss_curves.html" height="520" width="1000" scrolling="no" frameborder="0"></iframe>
2440
  <!-- Hynek uncomment this once it's added to -->
2441
  <!-- <div class="l-body-outset" id="fragment-fp8_training_loss_curves"></div> -->
@@ -2444,7 +2460,7 @@
2444
 
2445
  <p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
2446
 
2447
- <p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with smaller range, we need to normalize the range of values by computing the absolute maximum. DeepSeek-V3 also introduces a quantization scheme, where the ranges are normalized per tile: 1x128 for inputs/activations and 128x128 for weights and scale elements. This makes the normalization less susceptible to outliers. There is a number of additional tricks they deploy to also reduce the memory and communication footprint which you can follow in section 3.3. of the DeepSeek-V3 technical report<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. </p>
2448
 
2449
  <p>Here’s a summary of a few known approaches to FP8 training:</p>
2450
 
@@ -2525,11 +2541,13 @@
2525
  </tbody>
2526
  </table>
2527
 
2528
- <p>Overall, FP8 is still an experimental technique and methods are evolving, but will likely become the standard soon replacing bf16 mixed-precision. To follow a public implementations of this, please head to the nanotron’s implementation in <a href="https://github.com/huggingface/nanotron/pull/70">this PR</a>. </p>
2529
 
2530
- <p>In the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">have been announced </a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
 
 
2531
 
2532
- <p>We now arrived at the end of the distributed training journey. Let’s take a step back and conclude.</p>
2533
 
2534
  <h2>Conclusion</h2>
2535
 
 
345
  </p></div>
346
  </div>
347
 
348
+ <p>These items are stored as tensors which come in different <em>shapes</em> and <em>precisions</em>. The <em>shapes</em> are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding, as we’ll see later. <em>Precision</em> refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor. We will have a full discussion of the different precisions and their trade-offs in the <a target="_self" href="#mixed_precision_training">Mixed Precision Training</a> section; for now, let's just keep in mind that the memory requirements of these various formats differ, and that this impacts the memory usage of the items we need to store.</p>
349
 
350
  <p>So how can I quickly determine memory usage from these variable? One simple way is to do this empirically and just measure it.</p>
351
 
 
815
 
816
  <h4>Memory usage revisited</h4>
817
 
818
+ <p>You likely remember from <a target="_self" href="#memory_usage_in_transformers">our previous section</a> the memory usage of optimizer states, gradients, and parameters during standard training. Let’s call our model's parameter count <d-math>\Psi</d-math> (previously N, but here we use the original ZeRO paper's notation). In mixed precision training (see the <a target="_self" href="#mixed_precision_training">Mixed Precision Training</a> section for more details) with the Adam optimizer, the memory usage for each item we need to store is:</p>
819
 
820
  <ul>
821
  <li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
 
2274
 
2275
  <p>In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.</p>
2276
 
2277
+ <p>Non-blocking execution can be useful for overlapping communication and computation, as we saw many times along our journey, but it can be extended to the more general idea of trying to avoid, at all costs, going back and forth between host and GPU kernel commands.</p>
2278
+ <p>This idea is beautifully illustrated by <a href="https://horace.io/brrr_intro.html">Horace He</a> in these diagrams:</p>
2279
  <div style="display: flex; gap: 20px; align-items: flex-start;">
2280
  <div style="width: 50%;">
2281
  <img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
2282
+ <div class="figure-legend">
2283
  <p>A sequence of kernels requiring back and forth between global memory and compute units</p>
2284
  </div>
2285
+ </div>
2286
  <div style="width: 50%;">
2287
  <img alt="image.png" src="/assets/images/fused_kernels2.png" style="width: 100%;" />
2288
+ <div class="figure-legend">
2289
+ <p>Instead of sending our triangle back to global memory just to read it back again, we instead just do all of our operations in one go.</p>
2290
+ </div>
2291
  </div>
2292
  </div>
2293
 
2294
  <p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p>
2295
 
2296
 
2297
+ <p>Fused kernels are especially efficient and simple to write for a succession of point-like operations that are performed independently of each other on each input token. In this case, there is no point in sending computed values back to global memory before moving them to SM memory and spinning up a new kernel. It’s much more efficient to keep all values local until the whole succession of computations has been performed.</p>
2298
 
2299
+ <p>There are many places in a Transformer model where this "fusing" approach can be applied: every time we have a succession of point-wise operations, e.g. in the computations involved in layer norms.</p>
2300
 
2301
  <p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>Flash Attention</em></strong></p>
2302
 
2303
  <h3>Flash Attention 1-3</h3>
2304
 
2305
+ <p>Flash Attention was introduced by <a href="https://tridao.me">Tri Dao</a> and optimizes the attention computation by writing custom CUDA kernels that make it much faster <em>and</em> more memory efficient. The idea behind Flash Attention is to make efficient use of the various memories of the GPU to avoid relying too much on the slowest one: the global memory of the GPU.</p>
2306
+ <aside>Note that the global memory of the GPU is confusingly called the "High Bandwidth Memory", HBM 🫠</aside>
2307
 
2308
  <p>A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
2309
 
 
2330
 
2331
  <p>Flash-Attention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.</p>
2332
 
2333
+ <hr>
2334
+
2335
+ <p>The techniques described so far in this operation-fusion section have required us to change our modeling code and write custom kernels for certain operations in order to speed up training.</p>
2336
+ <p>In the final section of our low-level dive into the compute operations themselves, we will take a look at a family of methods that are agnostic to the modeling code, can be used for any model, and are so widely used that they have become an industry standard: <strong>Mixed Precision Training</strong>!</p>
2337
 
2338
  <h3>Mixed Precision Training</h3>
2339
+
2340
+ <p>In various sections along this book, we've talked about lower precision formats and their impact on the memory requirements for storing activations, parameters, and optimizer states. It's now time to dive deeper into the details of these formats and better understand their trade-offs, advantages, and limitations.</p>
2341
 
2342
  <p>Mixed Precision Training, as the name suggests, involves mixing different precisions when training. The default numerical precision of PyTorch tensors is single-precision floating point format or also called FP32 or float32 which means that every number stored takes up 32 bits or 4 bytes. The available bits to represent a number are divided into 3 parts:</p>
2343
 
 
2347
  <li>Exponent: controls the magnitude of the number</li>
2348
  </ul>
2349
 
2350
+ <p><img width="500px" alt="sign-mantissa-exponent.svg" src="/assets/images/sign-mantissa-exponent.svg" /></p>
2351
+
2352
+
2353
+
2354
  <p>The principle of floating point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. <d-math>- 5.734 \times 10^{7}</d-math>, where we first have the sign, followed by the mantissa an the exponent. As such we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default there is a range of floating point formats available in PyTorch:</p>
2355
 
2356
  <p></p>
 
2361
  <th><strong>Format</strong></th>
2362
  <th><strong>Total bits</strong></th>
2363
  <th><strong>Sign</strong></th>
 
2364
  <th><strong>Exponent</strong></th>
2365
+ <th><strong>Mantissa</strong></th>
2366
  </tr>
2367
  </thead>
2368
  <tbody>
 
2370
  <td>float32</td>
2371
  <td>32</td>
2372
  <td>1</td>
 
2373
  <td>8</td>
2374
+ <td>23</td>
2375
  </tr>
2376
  <tr>
2377
  <td>float16</td>
2378
  <td>16</td>
2379
  <td>1</td>
 
2380
  <td>5</td>
2381
+ <td>10</td>
2382
  </tr>
2383
  <tr>
2384
  <td>bfloat16</td>
2385
  <td>16</td>
2386
  <td>1</td>
 
2387
  <td>8</td>
2388
+ <td>7</td>
2389
  </tr>
2390
  <tr>
2391
  <td>float8 (e4m3)</td>
2392
  <td>8</td>
2393
  <td>1</td>
 
2394
  <td>4</td>
2395
+ <td>3</td>
2396
  </tr>
2397
  <tr>
2398
  <td>float8 (e5m2)</td>
2399
  <td>8</td>
2400
  <td>1</td>
 
2401
  <td>5</td>
2402
+ <td>2</td>
2403
  </tr>
2404
  </tbody>
2405
  </table>
 
2419
 
2420
  <p>We can see here that bfloat16 maintained the range of float32 over float16 but did this with the cost of sacrificing more precision. In case of float8 the situation is even more dire as e4m3 can represent 7 and e5m2 only 3 number on the interval 1-2.</p>
2421
 
2422
+ <p>A common metric to measure a format's resolution is epsilon: the gap between <d-math>1.00</d-math> and the first representable number after it. We can see that for the float32 format <d-math>10^{-4}</d-math> is an upper bound (it’s actually <d-math>1.19 \times 10^{-7}</d-math>). For float16 it is <d-math>\sim 10^{-3}</d-math> and for bfloat16 it is 10x higher still.</p>
2423
 
2424
+ <p>The idea of mixed precision training is to use some of these lower precision formats while maintaining the performance of full precision training. </p>
2425
 
2426
+ <p>It turns out we <strong>can’t</strong> totally abandon float32 and usually will need to maintain some parts in full precision. This is why lower precision training is usually called <strong><em>mixed precision</em></strong> training. </p>
2427
 
2428
  <p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further all the way down to 8 bits.</p>
2429
 
 
2436
  <ol>
2437
  <li><strong>FP32 copy of weights</strong>: There are two possible issues with float16 weights. During training some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to zero, if the updates are very small the difference in magnitude can cause the weights to underflow during the addition. Once the weights are zero they will remain 0 for the rest of training as there is no gradient signal coming through anymore.</li>
2438
  <li><strong>Loss scaling</strong>: We have a similar issue with the gradients as well as gradients tend to be much smaller than 1 and are thus at risk to underflow. A simple, yet effective, strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass and the scaling is not affecting training as we unscale before processing the gradients further (e.g. clipping) and the optimization step. </li>
2439
+ <li><strong>Accumulation</strong>: Finally, when performing certain arithmetic operations in 16-bit precision, such as averages or summations, we can also face under- or overflows. A solution is then to accumulate intermediate results in float32 during the operation and only cast the final result back to 16-bit precision.</li>
2440
  </ol>
2441
 
2442
+ <p>With these techniques, we can get stable training while benefiting from higher throughput thanks to the faster, lower precision arithmetic operations. Naturally, as a curious reader by now slightly addicted to maximizing throughput, you may ask: can we go further and faster than 16-bit precision? </p>
2443
 
2444
  <p>Maybe!</p>
2445
 
 
2451
 
2452
  <p>We know that instability increases as learning rates rise for a fixed model size<d-cite bibtex-key="wortsman2023smallscaleproxieslargescaletransformer"></d-cite>, making FP8 pretraining particularly tricky.</p>
2453
 
2454
+ <p>Here is an example of a typical divergent loss curve for FP8 training:</p>
2455
  <iframe class="l-body-outset" id="plotFP8Loss" src="/assets/data/fp8/fp8_training_loss_curves.html" height="520" width="1000" scrolling="no" frameborder="0"></iframe>
2456
  <!-- Hynek uncomment this once it's added to -->
2457
  <!-- <div class="l-body-outset" id="fragment-fp8_training_loss_curves"></div> -->
 
2460
 
2461
  <p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
2462
 
2463
+ <p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with a smaller range, we need to normalize the range of activation values, for instance by computing their absolute maximum. DeepSeek-V3 further introduced a specific quantization scheme where the ranges are normalized per tile: 1x128 for inputs/activations and 128x128 for weights and scale elements. This makes the normalization less strongly impacted by outlier values in the activations. There are a number of additional tricks they proposed to further reduce the memory and communication footprint, which you can follow in section 3.3 of the DeepSeek-V3 technical report<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. </p>
2464
 
2465
  <p>Here’s a summary of a few known approaches to FP8 training:</p>
2466
 
 
2541
  </tbody>
2542
  </table>
2543
 
2544
+ <p>Overall, FP8 remains, as of early 2025, an experimental technique and methods are still evolving. Given its obvious benefits, it will likely become the standard and soon replace bf16 mixed precision. To follow an open-source implementation of FP8 training techniques, head to nanotron’s implementation in <a href="https://github.com/huggingface/nanotron/pull/70">this PR</a>. </p>
2545
 
2546
+ <p>Projecting further into the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">has been announced</a> to support FP4 training, which would further speed up training but will without a doubt also introduce new training stability challenges.</p>
2547
+
2548
+ <hr>
2549
 
2550
+ <p>This last section concludes our long journey through the land of fast and large-scale model training on tens to thousands of GPUs. Time to slowly bring our GPU cluster to rest and take a step back to reflect on all we've learned along the way.</p>
2551
 
2552
  <h2>Conclusion</h2>
2553