Spaces:
Running
Running
more fixes thom (#61)
Browse files- updates (fba245a892e8971bd66633f3b5ffd0d478c4e6b8)
- assets/images/sign-mantissa-exponent.svg +3 -0
- dist/assets/images/sign-mantissa-exponent.svg +1 -0
- dist/index.html +41 -23
- src/index.html +41 -23
assets/images/sign-mantissa-exponent.svg
ADDED
|
dist/assets/images/sign-mantissa-exponent.svg
ADDED
|
dist/index.html
CHANGED
@@ -345,7 +345,7 @@
|
|
345 |
</p></div>
|
346 |
</div>
|
347 |
|
348 |
-
<p>These items are stored as tensors which come in different <em>shapes</em> and <em>precisions</em>. The <em>shapes</em> are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding as we’ll see later. <em>Precision</em> refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor.</p>
|
349 |
|
350 |
<p>So how can I quickly determine memory usage from these variable? One simple way is to do this empirically and just measure it.</p>
|
351 |
|
@@ -815,7 +815,7 @@
|
|
815 |
|
816 |
<h4>Memory usage revisited</h4>
|
817 |
|
818 |
-
<p>You likely remember from <a target="_self" href="#memory_usage_in_transformers"> our previous section</a> the memory usage of optimizer states, gradients, and parameters during a standard training. Lets call our model's parameters count <d-math>\Psi</d-math> (previously N but here we use the original ZeRO paper notation). In
|
819 |
|
820 |
<ul>
|
821 |
<li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
|
@@ -2274,30 +2274,36 @@
|
|
2274 |
|
2275 |
<p>In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.</p>
|
2276 |
|
2277 |
-
<p>Non-blocking can be useful for overlapping communication and computation as we saw
|
|
|
2278 |
<div style="display: flex; gap: 20px; align-items: flex-start;">
|
2279 |
<div style="width: 50%;">
|
2280 |
<img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
|
|
|
2281 |
<p>A sequence of kernels requiring back and forth between global memory and compute units</p>
|
2282 |
</div>
|
|
|
2283 |
<div style="width: 50%;">
|
2284 |
<img alt="image.png" src="/assets/images/fused_kernels2.png" style="width: 100%;" />
|
2285 |
-
<
|
|
|
|
|
2286 |
</div>
|
2287 |
</div>
|
2288 |
|
2289 |
<p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p>
|
2290 |
|
2291 |
|
2292 |
-
<p>Fused kernel are especially efficient and simple to write for succession of point-like operations which are performed independently of each other on each input tokens. In this case, there is no point in bringing back computed values in Global Memory before moving them to SM memory and spinning up a new kernel. It’s much more efficient to keep all values
|
2293 |
|
2294 |
-
<p>
|
2295 |
|
2296 |
<p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>Flash Attention</em></strong></p>
|
2297 |
|
2298 |
<h3>Flash Attention 1-3</h3>
|
2299 |
|
2300 |
-
<p>Flash attention
|
|
|
2301 |
|
2302 |
<p>A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
|
2303 |
|
@@ -2324,9 +2330,14 @@
|
|
2324 |
|
2325 |
<p>Flash-Attention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.</p>
|
2326 |
|
2327 |
-
<
|
|
|
|
|
|
|
2328 |
|
2329 |
<h3>Mixed Precision Training</h3>
|
|
|
|
|
2330 |
|
2331 |
<p>Mixed Precision Training, as the name suggests, involves mixing different precisions when training. The default numerical precision of PyTorch tensors is single-precision floating point format or also called FP32 or float32 which means that every number stored takes up 32 bits or 4 bytes. The available bits to represent a number are divided into 3 parts:</p>
|
2332 |
|
@@ -2336,6 +2347,10 @@
|
|
2336 |
<li>Exponent: controls the magnitude of the number</li>
|
2337 |
</ul>
|
2338 |
|
|
|
|
|
|
|
|
|
2339 |
<p>The principle of floating point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. <d-math>- 5.734 \times 10^{7}</d-math>, where we first have the sign, followed by the mantissa an the exponent. As such we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default there is a range of floating point formats available in PyTorch:</p>
|
2340 |
|
2341 |
<p></p>
|
@@ -2346,8 +2361,8 @@
|
|
2346 |
<th><strong>Format</strong></th>
|
2347 |
<th><strong>Total bits</strong></th>
|
2348 |
<th><strong>Sign</strong></th>
|
2349 |
-
<th><strong>Mantissa</strong></th>
|
2350 |
<th><strong>Exponent</strong></th>
|
|
|
2351 |
</tr>
|
2352 |
</thead>
|
2353 |
<tbody>
|
@@ -2355,36 +2370,36 @@
|
|
2355 |
<td>float32</td>
|
2356 |
<td>32</td>
|
2357 |
<td>1</td>
|
2358 |
-
<td>23</td>
|
2359 |
<td>8</td>
|
|
|
2360 |
</tr>
|
2361 |
<tr>
|
2362 |
<td>float16</td>
|
2363 |
<td>16</td>
|
2364 |
<td>1</td>
|
2365 |
-
<td>10</td>
|
2366 |
<td>5</td>
|
|
|
2367 |
</tr>
|
2368 |
<tr>
|
2369 |
<td>bfloat16</td>
|
2370 |
<td>16</td>
|
2371 |
<td>1</td>
|
2372 |
-
<td>7</td>
|
2373 |
<td>8</td>
|
|
|
2374 |
</tr>
|
2375 |
<tr>
|
2376 |
<td>float8 (e4m3)</td>
|
2377 |
<td>8</td>
|
2378 |
<td>1</td>
|
2379 |
-
<td>3</td>
|
2380 |
<td>4</td>
|
|
|
2381 |
</tr>
|
2382 |
<tr>
|
2383 |
<td>float8 (e5m2)</td>
|
2384 |
<td>8</td>
|
2385 |
<td>1</td>
|
2386 |
-
<td>2</td>
|
2387 |
<td>5</td>
|
|
|
2388 |
</tr>
|
2389 |
</tbody>
|
2390 |
</table>
|
@@ -2404,11 +2419,11 @@
|
|
2404 |
|
2405 |
<p>We can see here that bfloat16 maintained the range of float32 over float16 but did this with the cost of sacrificing more precision. In case of float8 the situation is even more dire as e4m3 can represent 7 and e5m2 only 3 number on the interval 1-2.</p>
|
2406 |
|
2407 |
-
<p>A common metric to measure a formats resolution is epsilon: the first representable number after 1.00
|
2408 |
|
2409 |
-
<p>The idea of mixed precision training is to use some of these lower precisions formats while maintaining the performance of full precision training.
|
2410 |
|
2411 |
-
<p>This is why lower precision training is usually called <strong><em>mixed precision</em></strong> training. </p>
|
2412 |
|
2413 |
<p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further all the way down to 8 bits.</p>
|
2414 |
|
@@ -2421,10 +2436,10 @@
|
|
2421 |
<ol>
|
2422 |
<li><strong>FP32 copy of weights</strong>: There are two possible issues with float16 weights. During training some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to zero, if the updates are very small the difference in magnitude can cause the weights to underflow during the addition. Once the weights are zero they will remain 0 for the rest of training as there is no gradient signal coming through anymore.</li>
|
2423 |
<li><strong>Loss scaling</strong>: We have a similar issue with the gradients as well as gradients tend to be much smaller than 1 and are thus at risk to underflow. A simple, yet effective, strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass and the scaling is not affecting training as we unscale before processing the gradients further (e.g. clipping) and the optimization step. </li>
|
2424 |
-
<li><strong>Accumulation</strong>: Finally, when performing arithmetic operations in
|
2425 |
</ol>
|
2426 |
|
2427 |
-
<p>With these techniques,
|
2428 |
|
2429 |
<p>Maybe!</p>
|
2430 |
|
@@ -2436,6 +2451,7 @@
|
|
2436 |
|
2437 |
<p>We know that instability increases as learning rates rise for a fixed model size<d-cite bibtex-key="wortsman2023smallscaleproxieslargescaletransformer"></d-cite>, making FP8 pretraining particularly tricky.</p>
|
2438 |
|
|
|
2439 |
<iframe class="l-body-outset" id="plotFP8Loss" src="/assets/data/fp8/fp8_training_loss_curves.html" height="520" width="1000" scrolling="no" frameborder="0"></iframe>
|
2440 |
<!-- Hynek uncomment this once it's added to -->
|
2441 |
<!-- <div class="l-body-outset" id="fragment-fp8_training_loss_curves"></div> -->
|
@@ -2444,7 +2460,7 @@
|
|
2444 |
|
2445 |
<p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
|
2446 |
|
2447 |
-
<p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with smaller range, we need to normalize the range of values by computing
|
2448 |
|
2449 |
<p>Here’s a summary of a few known approaches to FP8 training:</p>
|
2450 |
|
@@ -2525,11 +2541,13 @@
|
|
2525 |
</tbody>
|
2526 |
</table>
|
2527 |
|
2528 |
-
<p>Overall, FP8
|
2529 |
|
2530 |
-
<p>
|
|
|
|
|
2531 |
|
2532 |
-
<p>
|
2533 |
|
2534 |
<h2>Conclusion</h2>
|
2535 |
|
|
|
345 |
</p></div>
|
346 |
</div>
|
347 |
|
348 |
+
<p>These items are stored as tensors which come in different <em>shapes</em> and <em>precisions</em>. The <em>shapes</em> are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding as we’ll see later. <em>Precision</em> refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor. We will have a full discussion of the different precisions and their trade-offs in the <a target="_self" href="#mixed_precision_training">Mixed Precision Training</a> section, for now let's just keep in mind that the memory requirements for these various format will be different and that will impact the memory usage of the items we need to store.</p>
|
349 |
|
350 |
<p>So how can I quickly determine memory usage from these variable? One simple way is to do this empirically and just measure it.</p>
|
351 |
|
|
|
815 |
|
816 |
<h4>Memory usage revisited</h4>
|
817 |
|
818 |
+
<p>You likely remember from <a target="_self" href="#memory_usage_in_transformers"> our previous section</a> the memory usage of optimizer states, gradients, and parameters during a standard training. Lets call our model's parameters count <d-math>\Psi</d-math> (previously N but here we use the original ZeRO paper notation). In <a target="_self" href="#mixed_precision_training">Mixed Precision Training</a> (more details in a later section) with the Adam optimizer, the memory usage for each item we need to store is:</p>
|
819 |
|
820 |
<ul>
|
821 |
<li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
|
|
|
2274 |
|
2275 |
<p>In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.</p>
|
2276 |
|
2277 |
+
<p>Non-blocking can be useful for overlapping communication and computation –as we saw many times along our journey– but can be extended to the more general idea of trying to avoid at all cost going back and forth between host and GPU kernel commands.</p>
|
2278 |
+
<p>This idea is beautifully illustrated by <a href="https://horace.io/brrr_intro.html">Horace He</a> in these diagrams:</p>
|
2279 |
<div style="display: flex; gap: 20px; align-items: flex-start;">
|
2280 |
<div style="width: 50%;">
|
2281 |
<img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
|
2282 |
+
<div class="figure-legend">
|
2283 |
<p>A sequence of kernels requiring back and forth between global memory and compute units</p>
|
2284 |
</div>
|
2285 |
+
</div>
|
2286 |
<div style="width: 50%;">
|
2287 |
<img alt="image.png" src="/assets/images/fused_kernels2.png" style="width: 100%;" />
|
2288 |
+
<div class="figure-legend">
|
2289 |
+
<p>Instead of sending our triangle back to global memory just to read it back again, we instead just do all of our operations in one go.</p>
|
2290 |
+
</div>
|
2291 |
</div>
|
2292 |
</div>
|
2293 |
|
2294 |
<p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p>
|
2295 |
|
2296 |
|
2297 |
+
<p>Fused kernel are especially efficient and simple to write for succession of point-like operations which are performed independently of each other on each input tokens. In this case, there is no point in bringing back computed values in Global Memory before moving them to SM memory and spinning up a new kernel. It’s much more efficient to keep all values locally until the succession of computation has been performed.</p>
|
2298 |
|
2299 |
+
<p>There are many places in a Transformer model where this "fusing" approach can be applied: every time we have a succession of point-wise operations e.g. in the computation involved in the Layer norms.</p>
|
2300 |
|
2301 |
<p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>Flash Attention</em></strong></p>
|
2302 |
|
2303 |
<h3>Flash Attention 1-3</h3>
|
2304 |
|
2305 |
+
<p>Flash attention was introduced by <a href="https://tridao.me">Tri Dao</a> and proposed to optimize the attention computations by writing custom CUDA kernels make them much faster *and* more memory efficient. The idea behind Flash Attention is to make efficient use of the various memories of the GPU to avoid relying too much on the slowest one: the global memory of the GPU.</p>
|
2306 |
+
<aside>Note that the global memory of the GPU is confusingly called the "High Bandwidth Memory", HBM 🫠</aside>
|
2307 |
|
2308 |
<p>A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
|
2309 |
|
|
|
2330 |
|
2331 |
<p>Flash-Attention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.</p>
|
2332 |
|
2333 |
+
<hr>
|
2334 |
+
|
2335 |
+
<p>The techniques described so far in this operation-fusion section have required us to implement modeling code changes and write custom kernels for certain operations in order to speed up training.</p>
|
2336 |
+
<p>In the final section of our low-level dive in the compute operations themselves, we will take a look at a range of methods that are agnostic to the modeling code and can be used for any model and are so widely used that they have become a standard in the industry: <strong>Mixed Precision Training</strong>!</p>
|
2337 |
|
2338 |
<h3>Mixed Precision Training</h3>
|
2339 |
+
|
2340 |
+
<p>In various sections along this book, we've talked about lower precisions formats and their impact on the memory requirements for storing activations, parameters and optimizer states. It's now time to dive deeper in the details of these formats and understand better their trade-offs, advantages and limitations.</p>
|
2341 |
|
2342 |
<p>Mixed Precision Training, as the name suggests, involves mixing different precisions when training. The default numerical precision of PyTorch tensors is single-precision floating point format or also called FP32 or float32 which means that every number stored takes up 32 bits or 4 bytes. The available bits to represent a number are divided into 3 parts:</p>
|
2343 |
|
|
|
2347 |
<li>Exponent: controls the magnitude of the number</li>
|
2348 |
</ul>
|
2349 |
|
2350 |
+
<p><img width="500px" alt="sign-mantissa-exponent.svg" src="/assets/images/sign-mantissa-exponent.svg" /></p>
|
2351 |
+
|
2352 |
+
|
2353 |
+
|
2354 |
<p>The principle of floating point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. <d-math>- 5.734 \times 10^{7}</d-math>, where we first have the sign, followed by the mantissa an the exponent. As such we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default there is a range of floating point formats available in PyTorch:</p>
|
2355 |
|
2356 |
<p></p>
|
|
|
2361 |
<th><strong>Format</strong></th>
|
2362 |
<th><strong>Total bits</strong></th>
|
2363 |
<th><strong>Sign</strong></th>
|
|
|
2364 |
<th><strong>Exponent</strong></th>
|
2365 |
+
<th><strong>Mantissa</strong></th>
|
2366 |
</tr>
|
2367 |
</thead>
|
2368 |
<tbody>
|
|
|
2370 |
<td>float32</td>
|
2371 |
<td>32</td>
|
2372 |
<td>1</td>
|
|
|
2373 |
<td>8</td>
|
2374 |
+
<td>23</td>
|
2375 |
</tr>
|
2376 |
<tr>
|
2377 |
<td>float16</td>
|
2378 |
<td>16</td>
|
2379 |
<td>1</td>
|
|
|
2380 |
<td>5</td>
|
2381 |
+
<td>10</td>
|
2382 |
</tr>
|
2383 |
<tr>
|
2384 |
<td>bfloat16</td>
|
2385 |
<td>16</td>
|
2386 |
<td>1</td>
|
|
|
2387 |
<td>8</td>
|
2388 |
+
<td>7</td>
|
2389 |
</tr>
|
2390 |
<tr>
|
2391 |
<td>float8 (e4m3)</td>
|
2392 |
<td>8</td>
|
2393 |
<td>1</td>
|
|
|
2394 |
<td>4</td>
|
2395 |
+
<td>3</td>
|
2396 |
</tr>
|
2397 |
<tr>
|
2398 |
<td>float8 (e5m2)</td>
|
2399 |
<td>8</td>
|
2400 |
<td>1</td>
|
|
|
2401 |
<td>5</td>
|
2402 |
+
<td>2</td>
|
2403 |
</tr>
|
2404 |
</tbody>
|
2405 |
</table>
|
|
|
2419 |
|
2420 |
<p>We can see here that bfloat16 maintained the range of float32 over float16 but did this with the cost of sacrificing more precision. In case of float8 the situation is even more dire as e4m3 can represent 7 and e5m2 only 3 number on the interval 1-2.</p>
|
2421 |
|
2422 |
+
<p>A common metric to measure a formats resolution is epsilon: the first representable number after <d-math>1.00</d-math>. We can see that for the float32 format <d-math>10^{-4}</d-math> is an upper bound (it’s actually <d-math>1.19^{-7}</d-math>). For float16 it is <d-math>\tilde 10^{-3}</d-math> and for bfloat 10x higher still.</p>
|
2423 |
|
2424 |
+
<p>The idea of mixed precision training is to use some of these lower precisions formats while maintaining the performance of full precision training. </p>
|
2425 |
|
2426 |
+
<p>It turns out we <strong>can’t</strong> totally abandon float32 and usually will need to maintain some parts in full precision. This is why lower precision training is usually called <strong><em>mixed precision</em></strong> training. </p>
|
2427 |
|
2428 |
<p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further all the way down to 8 bits.</p>
|
2429 |
|
|
|
2436 |
<ol>
|
2437 |
<li><strong>FP32 copy of weights</strong>: There are two possible issues with float16 weights. During training some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to zero, if the updates are very small the difference in magnitude can cause the weights to underflow during the addition. Once the weights are zero they will remain 0 for the rest of training as there is no gradient signal coming through anymore.</li>
|
2438 |
<li><strong>Loss scaling</strong>: We have a similar issue with the gradients as well as gradients tend to be much smaller than 1 and are thus at risk to underflow. A simple, yet effective, strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass and the scaling is not affecting training as we unscale before processing the gradients further (e.g. clipping) and the optimization step. </li>
|
2439 |
+
<li><strong>Accumulation</strong>: Finally, when performing certain arithmetic operations in 16-bit precision such as averages or summations, we can also face under or overflows. A solution is then to accumulate intermediate results in float32 during the operation and only cast the final result back to 16 bit precision.</li>
|
2440 |
</ol>
|
2441 |
|
2442 |
+
<p>With these techniques, we can get a stable training while benefitting from a higher throughput due to the faster, lower precision arithmetic operations. Naturally, as a curious reader –and by now slightly addicted to maximizing the throughput– you may ask the question: can we go further and faster than 16-bit precision? </p>
|
2443 |
|
2444 |
<p>Maybe!</p>
|
2445 |
|
|
|
2451 |
|
2452 |
<p>We know that instability increases as learning rates rise for a fixed model size<d-cite bibtex-key="wortsman2023smallscaleproxieslargescaletransformer"></d-cite>, making FP8 pretraining particularly tricky.</p>
|
2453 |
|
2454 |
+
<p>Here is an example of a typically divergent loss curve for FP8 training:</p>
|
2455 |
<iframe class="l-body-outset" id="plotFP8Loss" src="/assets/data/fp8/fp8_training_loss_curves.html" height="520" width="1000" scrolling="no" frameborder="0"></iframe>
|
2456 |
<!-- Hynek uncomment this once it's added to -->
|
2457 |
<!-- <div class="l-body-outset" id="fragment-fp8_training_loss_curves"></div> -->
|
|
|
2460 |
|
2461 |
<p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
|
2462 |
|
2463 |
+
<p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with smaller range, we need to normalize the range of activation values, for instance by computing their absolute maximum. DeepSeek-V3 further introduced a specific quantization scheme where the ranges are normalized per tile: 1x128 for inputs/activations and 128x128 for weights and scale elements. This makes the normalization less strongly impacted by outlier values in the activations. There is a number of additional tricks they proposed to further reduce the memory and communication footprint which you can follow in section 3.3. of the DeepSeek-V3 technical report<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. </p>
|
2464 |
|
2465 |
<p>Here’s a summary of a few known approaches to FP8 training:</p>
|
2466 |
|
|
|
2541 |
</tbody>
|
2542 |
</table>
|
2543 |
|
2544 |
+
<p>Overall, FP8 remains –in early 2025– an experimental technique and methods are still evolving. Given its obvious benefits, it will likely become the standard and soon replace bf16 mixed-precision. To follow an open-source implementations of FP8 training techniques, please head to the nanotron’s implementation in <a href="https://github.com/huggingface/nanotron/pull/70">this PR</a>. </p>
|
2545 |
|
2546 |
+
<p>Projecting further into the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">have been announced </a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
|
2547 |
+
|
2548 |
+
<hr>
|
2549 |
|
2550 |
+
<p>This last section concluded our long journey in the land of fast and large model training on tens to thousands of GPUs. Time to slowly bring our GPU cluster to rest and take a step back to conclude on all we've learned along the way.</p>
|
2551 |
|
2552 |
<h2>Conclusion</h2>
|
2553 |
|
src/index.html
CHANGED
@@ -345,7 +345,7 @@
|
|
345 |
</p></div>
|
346 |
</div>
|
347 |
|
348 |
-
<p>These items are stored as tensors which come in different <em>shapes</em> and <em>precisions</em>. The <em>shapes</em> are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding as we’ll see later. <em>Precision</em> refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor.</p>
|
349 |
|
350 |
<p>So how can I quickly determine memory usage from these variable? One simple way is to do this empirically and just measure it.</p>
|
351 |
|
@@ -815,7 +815,7 @@
|
|
815 |
|
816 |
<h4>Memory usage revisited</h4>
|
817 |
|
818 |
-
<p>You likely remember from <a target="_self" href="#memory_usage_in_transformers"> our previous section</a> the memory usage of optimizer states, gradients, and parameters during a standard training. Lets call our model's parameters count <d-math>\Psi</d-math> (previously N but here we use the original ZeRO paper notation). In
|
819 |
|
820 |
<ul>
|
821 |
<li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
|
@@ -2274,30 +2274,36 @@
|
|
2274 |
|
2275 |
<p>In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.</p>
|
2276 |
|
2277 |
-
<p>Non-blocking can be useful for overlapping communication and computation as we saw
|
|
|
2278 |
<div style="display: flex; gap: 20px; align-items: flex-start;">
|
2279 |
<div style="width: 50%;">
|
2280 |
<img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
|
|
|
2281 |
<p>A sequence of kernels requiring back and forth between global memory and compute units</p>
|
2282 |
</div>
|
|
|
2283 |
<div style="width: 50%;">
|
2284 |
<img alt="image.png" src="/assets/images/fused_kernels2.png" style="width: 100%;" />
|
2285 |
-
<
|
|
|
|
|
2286 |
</div>
|
2287 |
</div>
|
2288 |
|
2289 |
<p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p>
|
2290 |
|
2291 |
|
2292 |
-
<p>Fused kernel are especially efficient and simple to write for succession of point-like operations which are performed independently of each other on each input tokens. In this case, there is no point in bringing back computed values in Global Memory before moving them to SM memory and spinning up a new kernel. It’s much more efficient to keep all values
|
2293 |
|
2294 |
-
<p>
|
2295 |
|
2296 |
<p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>Flash Attention</em></strong></p>
|
2297 |
|
2298 |
<h3>Flash Attention 1-3</h3>
|
2299 |
|
2300 |
-
<p>Flash attention
|
|
|
2301 |
|
2302 |
<p>A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
|
2303 |
|
@@ -2324,9 +2330,14 @@
|
|
2324 |
|
2325 |
<p>Flash-Attention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.</p>
|
2326 |
|
2327 |
-
<
|
|
|
|
|
|
|
2328 |
|
2329 |
<h3>Mixed Precision Training</h3>
|
|
|
|
|
2330 |
|
2331 |
<p>Mixed Precision Training, as the name suggests, involves mixing different precisions when training. The default numerical precision of PyTorch tensors is single-precision floating point format or also called FP32 or float32 which means that every number stored takes up 32 bits or 4 bytes. The available bits to represent a number are divided into 3 parts:</p>
|
2332 |
|
@@ -2336,6 +2347,10 @@
|
|
2336 |
<li>Exponent: controls the magnitude of the number</li>
|
2337 |
</ul>
|
2338 |
|
|
|
|
|
|
|
|
|
2339 |
<p>The principle of floating point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. <d-math>- 5.734 \times 10^{7}</d-math>, where we first have the sign, followed by the mantissa an the exponent. As such we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default there is a range of floating point formats available in PyTorch:</p>
|
2340 |
|
2341 |
<p></p>
|
@@ -2346,8 +2361,8 @@
|
|
2346 |
<th><strong>Format</strong></th>
|
2347 |
<th><strong>Total bits</strong></th>
|
2348 |
<th><strong>Sign</strong></th>
|
2349 |
-
<th><strong>Mantissa</strong></th>
|
2350 |
<th><strong>Exponent</strong></th>
|
|
|
2351 |
</tr>
|
2352 |
</thead>
|
2353 |
<tbody>
|
@@ -2355,36 +2370,36 @@
|
|
2355 |
<td>float32</td>
|
2356 |
<td>32</td>
|
2357 |
<td>1</td>
|
2358 |
-
<td>23</td>
|
2359 |
<td>8</td>
|
|
|
2360 |
</tr>
|
2361 |
<tr>
|
2362 |
<td>float16</td>
|
2363 |
<td>16</td>
|
2364 |
<td>1</td>
|
2365 |
-
<td>10</td>
|
2366 |
<td>5</td>
|
|
|
2367 |
</tr>
|
2368 |
<tr>
|
2369 |
<td>bfloat16</td>
|
2370 |
<td>16</td>
|
2371 |
<td>1</td>
|
2372 |
-
<td>7</td>
|
2373 |
<td>8</td>
|
|
|
2374 |
</tr>
|
2375 |
<tr>
|
2376 |
<td>float8 (e4m3)</td>
|
2377 |
<td>8</td>
|
2378 |
<td>1</td>
|
2379 |
-
<td>3</td>
|
2380 |
<td>4</td>
|
|
|
2381 |
</tr>
|
2382 |
<tr>
|
2383 |
<td>float8 (e5m2)</td>
|
2384 |
<td>8</td>
|
2385 |
<td>1</td>
|
2386 |
-
<td>2</td>
|
2387 |
<td>5</td>
|
|
|
2388 |
</tr>
|
2389 |
</tbody>
|
2390 |
</table>
|
@@ -2404,11 +2419,11 @@
|
|
2404 |
|
2405 |
<p>We can see here that bfloat16 maintained the range of float32 over float16 but did this with the cost of sacrificing more precision. In case of float8 the situation is even more dire as e4m3 can represent 7 and e5m2 only 3 number on the interval 1-2.</p>
|
2406 |
|
2407 |
-
<p>A common metric to measure a formats resolution is epsilon: the first representable number after 1.00
|
2408 |
|
2409 |
-
<p>The idea of mixed precision training is to use some of these lower precisions formats while maintaining the performance of full precision training.
|
2410 |
|
2411 |
-
<p>This is why lower precision training is usually called <strong><em>mixed precision</em></strong> training. </p>
|
2412 |
|
2413 |
<p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further all the way down to 8 bits.</p>
|
2414 |
|
@@ -2421,10 +2436,10 @@
|
|
2421 |
<ol>
|
2422 |
<li><strong>FP32 copy of weights</strong>: There are two possible issues with float16 weights. During training some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to zero, if the updates are very small the difference in magnitude can cause the weights to underflow during the addition. Once the weights are zero they will remain 0 for the rest of training as there is no gradient signal coming through anymore.</li>
|
2423 |
<li><strong>Loss scaling</strong>: We have a similar issue with the gradients as well as gradients tend to be much smaller than 1 and are thus at risk to underflow. A simple, yet effective, strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass and the scaling is not affecting training as we unscale before processing the gradients further (e.g. clipping) and the optimization step. </li>
|
2424 |
-
<li><strong>Accumulation</strong>: Finally, when performing arithmetic operations in
|
2425 |
</ol>
|
2426 |
|
2427 |
-
<p>With these techniques,
|
2428 |
|
2429 |
<p>Maybe!</p>
|
2430 |
|
@@ -2436,6 +2451,7 @@
|
|
2436 |
|
2437 |
<p>We know that instability increases as learning rates rise for a fixed model size<d-cite bibtex-key="wortsman2023smallscaleproxieslargescaletransformer"></d-cite>, making FP8 pretraining particularly tricky.</p>
|
2438 |
|
|
|
2439 |
<iframe class="l-body-outset" id="plotFP8Loss" src="/assets/data/fp8/fp8_training_loss_curves.html" height="520" width="1000" scrolling="no" frameborder="0"></iframe>
|
2440 |
<!-- Hynek uncomment this once it's added to -->
|
2441 |
<!-- <div class="l-body-outset" id="fragment-fp8_training_loss_curves"></div> -->
|
@@ -2444,7 +2460,7 @@
|
|
2444 |
|
2445 |
<p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
|
2446 |
|
2447 |
-
<p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with smaller range, we need to normalize the range of values by computing
|
2448 |
|
2449 |
<p>Here’s a summary of a few known approaches to FP8 training:</p>
|
2450 |
|
@@ -2525,11 +2541,13 @@
|
|
2525 |
</tbody>
|
2526 |
</table>
|
2527 |
|
2528 |
-
<p>Overall, FP8
|
2529 |
|
2530 |
-
<p>
|
|
|
|
|
2531 |
|
2532 |
-
<p>
|
2533 |
|
2534 |
<h2>Conclusion</h2>
|
2535 |
|
|
|
345 |
</p></div>
|
346 |
</div>
|
347 |
|
348 |
+
<p>These items are stored as tensors which come in different <em>shapes</em> and <em>precisions</em>. The <em>shapes</em> are determined by hyper-parameters such as batch size, sequence length, model hidden dimensions, attention heads, vocabulary size, and potential model sharding as we’ll see later. <em>Precision</em> refers to formats like FP32, BF16, or FP8, which respectively require 4, 2, or 1 byte to store each single value in the tensor. We will have a full discussion of the different precisions and their trade-offs in the <a target="_self" href="#mixed_precision_training">Mixed Precision Training</a> section, for now let's just keep in mind that the memory requirements for these various format will be different and that will impact the memory usage of the items we need to store.</p>
|
349 |
|
350 |
<p>So how can I quickly determine memory usage from these variable? One simple way is to do this empirically and just measure it.</p>
|
351 |
|
|
|
815 |
|
816 |
<h4>Memory usage revisited</h4>
|
817 |
|
818 |
+
<p>You likely remember from <a target="_self" href="#memory_usage_in_transformers"> our previous section</a> the memory usage of optimizer states, gradients, and parameters during a standard training. Lets call our model's parameters count <d-math>\Psi</d-math> (previously N but here we use the original ZeRO paper notation). In <a target="_self" href="#mixed_precision_training">Mixed Precision Training</a> (more details in a later section) with the Adam optimizer, the memory usage for each item we need to store is:</p>
|
819 |
|
820 |
<ul>
|
821 |
<li>Model’s parameters (half precision i.e. bf16/fp16): <d-math>2\Psi</d-math></li>
|
|
|
2274 |
|
2275 |
<p>In several places now we’ve mentioned how GPU and CPU operation can be asynchronous. In particular, the host code on the CPU can schedule workload on the GPU in a non-blocking way.</p>
|
2276 |
|
2277 |
+
<p>Non-blocking can be useful for overlapping communication and computation –as we saw many times along our journey– but can be extended to the more general idea of trying to avoid at all cost going back and forth between host and GPU kernel commands.</p>
|
2278 |
+
<p>This idea is beautifully illustrated by <a href="https://horace.io/brrr_intro.html">Horace He</a> in these diagrams:</p>
|
2279 |
<div style="display: flex; gap: 20px; align-items: flex-start;">
|
2280 |
<div style="width: 50%;">
|
2281 |
<img alt="image.png" src="/assets/images/fused_kernels1.png" style="width: 100%;" />
|
2282 |
+
<div class="figure-legend">
|
2283 |
<p>A sequence of kernels requiring back and forth between global memory and compute units</p>
|
2284 |
</div>
|
2285 |
+
</div>
|
2286 |
<div style="width: 50%;">
|
2287 |
<img alt="image.png" src="/assets/images/fused_kernels2.png" style="width: 100%;" />
|
2288 |
+
<div class="figure-legend">
|
2289 |
+
<p>Instead of sending our triangle back to global memory just to read it back again, we instead just do all of our operations in one go.</p>
|
2290 |
+
</div>
|
2291 |
</div>
|
2292 |
</div>
|
2293 |
|
2294 |
<p>How can we avoid this back and forth? Well the best way is to make our GPU as autonomous as possible. This is achieved by packing as many successive compute operations together in a single kernel for the GPU to run, called a “Fused Kernel”.</p>
|
2295 |
|
2296 |
|
2297 |
+
<p>Fused kernel are especially efficient and simple to write for succession of point-like operations which are performed independently of each other on each input tokens. In this case, there is no point in bringing back computed values in Global Memory before moving them to SM memory and spinning up a new kernel. It’s much more efficient to keep all values locally until the succession of computation has been performed.</p>
|
2298 |
|
2299 |
+
<p>There are many places in a Transformer model where this "fusing" approach can be applied: every time we have a succession of point-wise operations e.g. in the computation involved in the Layer norms.</p>
|
2300 |
|
2301 |
<p>We now have all the understanding necessary to marvel at a true masterpiece of kernel engineering: <strong><em>Flash Attention</em></strong></p>
|
2302 |
|
2303 |
<h3>Flash Attention 1-3</h3>
|
2304 |
|
2305 |
+
<p>Flash attention was introduced by <a href="https://tridao.me">Tri Dao</a> and proposed to optimize the attention computations by writing custom CUDA kernels make them much faster *and* more memory efficient. The idea behind Flash Attention is to make efficient use of the various memories of the GPU to avoid relying too much on the slowest one: the global memory of the GPU.</p>
|
2306 |
+
<aside>Note that the global memory of the GPU is confusingly called the "High Bandwidth Memory", HBM 🫠</aside>
|
2307 |
|
2308 |
<p>A basic implementation of the attention mechanism involve a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM which means that the results need to be sent to HBM and then back to SRAM for the next computations:</p>
|
2309 |
|
|
|
2330 |
|
2331 |
<p>Flash-Attention is a master demonstration of the breakthrough improvements that can come when you take into account the internal memory/compute design of current GPU accelerators.</p>
|
2332 |
|
2333 |
+
<hr>
|
2334 |
+
|
2335 |
+
<p>The techniques described so far in this operation-fusion section have required us to implement modeling code changes and write custom kernels for certain operations in order to speed up training.</p>
|
2336 |
+
<p>In the final section of our low-level dive in the compute operations themselves, we will take a look at a range of methods that are agnostic to the modeling code and can be used for any model and are so widely used that they have become a standard in the industry: <strong>Mixed Precision Training</strong>!</p>
|
2337 |
|
2338 |
<h3>Mixed Precision Training</h3>
|
2339 |
+
|
2340 |
+
<p>In various sections along this book, we've talked about lower precisions formats and their impact on the memory requirements for storing activations, parameters and optimizer states. It's now time to dive deeper in the details of these formats and understand better their trade-offs, advantages and limitations.</p>
|
2341 |
|
2342 |
<p>Mixed Precision Training, as the name suggests, involves mixing different precisions when training. The default numerical precision of PyTorch tensors is single-precision floating point format or also called FP32 or float32 which means that every number stored takes up 32 bits or 4 bytes. The available bits to represent a number are divided into 3 parts:</p>
|
2343 |
|
|
|
2347 |
<li>Exponent: controls the magnitude of the number</li>
|
2348 |
</ul>
|
2349 |
|
2350 |
+
<p><img width="500px" alt="sign-mantissa-exponent.svg" src="/assets/images/sign-mantissa-exponent.svg" /></p>
|
2351 |
+
|
2352 |
+
|
2353 |
+
|
2354 |
<p>The principle of floating point numbers can be easily illustrated by recalling the scientific notation of numbers, e.g. <d-math>- 5.734 \times 10^{7}</d-math>, where we first have the sign, followed by the mantissa an the exponent. As such we can represent numbers across a wide range of magnitudes with an adaptive precision. Although float32 is the default there is a range of floating point formats available in PyTorch:</p>
|
2355 |
|
2356 |
<p></p>
|
|
|
2361 |
<th><strong>Format</strong></th>
|
2362 |
<th><strong>Total bits</strong></th>
|
2363 |
<th><strong>Sign</strong></th>
|
|
|
2364 |
<th><strong>Exponent</strong></th>
|
2365 |
+
<th><strong>Mantissa</strong></th>
|
2366 |
</tr>
|
2367 |
</thead>
|
2368 |
<tbody>
|
|
|
2370 |
<td>float32</td>
|
2371 |
<td>32</td>
|
2372 |
<td>1</td>
|
|
|
2373 |
<td>8</td>
|
2374 |
+
<td>23</td>
|
2375 |
</tr>
|
2376 |
<tr>
|
2377 |
<td>float16</td>
|
2378 |
<td>16</td>
|
2379 |
<td>1</td>
|
|
|
2380 |
<td>5</td>
|
2381 |
+
<td>10</td>
|
2382 |
</tr>
|
2383 |
<tr>
|
2384 |
<td>bfloat16</td>
|
2385 |
<td>16</td>
|
2386 |
<td>1</td>
|
|
|
2387 |
<td>8</td>
|
2388 |
+
<td>7</td>
|
2389 |
</tr>
|
2390 |
<tr>
|
2391 |
<td>float8 (e4m3)</td>
|
2392 |
<td>8</td>
|
2393 |
<td>1</td>
|
|
|
2394 |
<td>4</td>
|
2395 |
+
<td>3</td>
|
2396 |
</tr>
|
2397 |
<tr>
|
2398 |
<td>float8 (e5m2)</td>
|
2399 |
<td>8</td>
|
2400 |
<td>1</td>
|
|
|
2401 |
<td>5</td>
|
2402 |
+
<td>2</td>
|
2403 |
</tr>
|
2404 |
</tbody>
|
2405 |
</table>
|
|
|
2419 |
|
2420 |
<p>We can see here that bfloat16 maintained the range of float32 over float16 but did this with the cost of sacrificing more precision. In case of float8 the situation is even more dire as e4m3 can represent 7 and e5m2 only 3 number on the interval 1-2.</p>
|
2421 |
|
2422 |
+
<p>A common metric to measure a formats resolution is epsilon: the first representable number after <d-math>1.00</d-math>. We can see that for the float32 format <d-math>10^{-4}</d-math> is an upper bound (it’s actually <d-math>1.19^{-7}</d-math>). For float16 it is <d-math>\tilde 10^{-3}</d-math> and for bfloat 10x higher still.</p>
|
2423 |
|
2424 |
+
<p>The idea of mixed precision training is to use some of these lower precisions formats while maintaining the performance of full precision training. </p>
|
2425 |
|
2426 |
+
<p>It turns out we <strong>can’t</strong> totally abandon float32 and usually will need to maintain some parts in full precision. This is why lower precision training is usually called <strong><em>mixed precision</em></strong> training. </p>
|
2427 |
|
2428 |
<p>Let’s now take a look at training models with 16 bits and then see if we can take it a step further all the way down to 8 bits.</p>
|
2429 |
|
|
|
2436 |
<ol>
|
2437 |
<li><strong>FP32 copy of weights</strong>: There are two possible issues with float16 weights. During training some of the weights can become very small and will be rounded to 0. However, even if the weights themselves are not close to zero, if the updates are very small the difference in magnitude can cause the weights to underflow during the addition. Once the weights are zero they will remain 0 for the rest of training as there is no gradient signal coming through anymore.</li>
|
2438 |
<li><strong>Loss scaling</strong>: We have a similar issue with the gradients as well as gradients tend to be much smaller than 1 and are thus at risk to underflow. A simple, yet effective, strategy is to scale the loss before the backward pass and unscale the gradients after the backward pass. This ensures that there is no underflow during the backward pass and the scaling is not affecting training as we unscale before processing the gradients further (e.g. clipping) and the optimization step. </li>
|
2439 |
+
<li><strong>Accumulation</strong>: Finally, when performing certain arithmetic operations in 16-bit precision such as averages or summations, we can also face under or overflows. A solution is then to accumulate intermediate results in float32 during the operation and only cast the final result back to 16 bit precision.</li>
|
2440 |
</ol>
|
2441 |
|
2442 |
+
<p>With these techniques, we can get a stable training while benefitting from a higher throughput due to the faster, lower precision arithmetic operations. Naturally, as a curious reader –and by now slightly addicted to maximizing the throughput– you may ask the question: can we go further and faster than 16-bit precision? </p>
|
2443 |
|
2444 |
<p>Maybe!</p>
|
2445 |
|
|
|
2451 |
|
2452 |
<p>We know that instability increases as learning rates rise for a fixed model size<d-cite bibtex-key="wortsman2023smallscaleproxieslargescaletransformer"></d-cite>, making FP8 pretraining particularly tricky.</p>
|
2453 |
|
2454 |
+
<p>Here is an example of a typically divergent loss curve for FP8 training:</p>
|
2455 |
<iframe class="l-body-outset" id="plotFP8Loss" src="/assets/data/fp8/fp8_training_loss_curves.html" height="520" width="1000" scrolling="no" frameborder="0"></iframe>
|
2456 |
<!-- Hynek uncomment this once it's added to -->
|
2457 |
<!-- <div class="l-body-outset" id="fragment-fp8_training_loss_curves"></div> -->
|
|
|
2460 |
|
2461 |
<p><img alt="image.png" src="/assets/images/fp8_diagram.png" /></p>
|
2462 |
|
2463 |
+
<p>In order to switch from high precision (e.g. FP32 or BF16) to lower precision (e.g. FP16 or FP8) with smaller range, we need to normalize the range of activation values, for instance by computing their absolute maximum. DeepSeek-V3 further introduced a specific quantization scheme where the ranges are normalized per tile: 1x128 for inputs/activations and 128x128 for weights and scale elements. This makes the normalization less strongly impacted by outlier values in the activations. There is a number of additional tricks they proposed to further reduce the memory and communication footprint which you can follow in section 3.3. of the DeepSeek-V3 technical report<d-cite bibtex-key="deepseekai2024deepseekv3technicalreport"></d-cite>. </p>
|
2464 |
|
2465 |
<p>Here’s a summary of a few known approaches to FP8 training:</p>
|
2466 |
|
|
|
2541 |
</tbody>
|
2542 |
</table>
|
2543 |
|
2544 |
+
<p>Overall, FP8 remains –in early 2025– an experimental technique and methods are still evolving. Given its obvious benefits, it will likely become the standard and soon replace bf16 mixed-precision. To follow an open-source implementations of FP8 training techniques, please head to the nanotron’s implementation in <a href="https://github.com/huggingface/nanotron/pull/70">this PR</a>. </p>
|
2545 |
|
2546 |
+
<p>Projecting further into the future, Blackwell, the next generation of NVIDIA chips, <a href="https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/">have been announced </a> to support FP4 training, further speeding up training but without a doubt also introducing a new training stability challenge.</p>
|
2547 |
+
|
2548 |
+
<hr>
|
2549 |
|
2550 |
+
<p>This last section concluded our long journey in the land of fast and large model training on tens to thousands of GPUs. Time to slowly bring our GPU cluster to rest and take a step back to conclude on all we've learned along the way.</p>
|
2551 |
|
2552 |
<h2>Conclusion</h2>
|
2553 |
|