Spaces:
Running
Running
app3 (#62)
Browse files- dist/index.html +196 -6
- src/index.html +196 -6
dist/index.html
CHANGED
@@ -1168,17 +1168,17 @@
|
|
1168 |
<p style="margin-bottom: 0;"><strong>First Transition (SP → TP)</strong></p>
|
1169 |
<ul style="margin-top: 0;">
|
1170 |
<li>"g" operation (all-gather) combines Y1<em> and Y2</em> back to full sequence length</li>
|
1171 |
-
<li> Restores Y (b,s,h) since column linear
|
1172 |
</ul>
|
1173 |
-
<p style="margin-bottom: 0;"><strong>First Linear
|
1174 |
<ul style="margin-top: 0;">
|
1175 |
-
<li>A1 is a column-linear
|
1176 |
<li>GeLU is applied independently on each GPU</li>
|
1177 |
<li>Z1* is (b,s,h/2)</li>
|
1178 |
</ul>
|
1179 |
-
<p style="margin-bottom: 0;"><strong>Second Linear
|
1180 |
<ul style="margin-top: 0;">
|
1181 |
-
<li>B1 is a row-linear
|
1182 |
<li>W1 is (b,s,h)</li>
|
1183 |
</ul>
|
1184 |
<p style="margin-bottom: 0;"><strong>Final Transition (TP → SP)</strong></p>
|
@@ -3491,9 +3491,199 @@
|
|
3491 |
|
3492 |
<p>Using this method, you can profile the custom CUDA kernel just as we demonstrated earlier with PyTorch's profiler or NVIDIA tools.</p>
|
3493 |
|
3494 |
-
<h3>A2:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3495 |
|
3496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3497 |
</d-article>
|
3498 |
|
3499 |
<d-appendix>
|
|
|
1168 |
<p style="margin-bottom: 0;"><strong>First Transition (SP → TP)</strong></p>
|
1169 |
<ul style="margin-top: 0;">
|
1170 |
<li>"g" operation (all-gather) combines Y1<em> and Y2</em> back to full sequence length</li>
|
1171 |
+
<li> Restores Y (b,s,h) since column linear needs full hidden dimension h</li>
|
1172 |
</ul>
|
1173 |
+
<p style="margin-bottom: 0;"><strong>First Linear (TP Region)</strong></p>
|
1174 |
<ul style="margin-top: 0;">
|
1175 |
+
<li>A1 is a column-linear, so it splits Y along the hidden dimension</li>
|
1176 |
<li>GeLU is applied independently on each GPU</li>
|
1177 |
<li>Z1* is (b,s,h/2)</li>
|
1178 |
</ul>
|
1179 |
+
<p style="margin-bottom: 0;"><strong>Second Linear (TP Region)</strong></p>
|
1180 |
<ul style="margin-top: 0;">
|
1181 |
+
<li>B1 is a row-linear, so it restores the hidden dimension</li>
|
1182 |
<li>W1 is (b,s,h)</li>
|
1183 |
</ul>
|
1184 |
<p style="margin-bottom: 0;"><strong>Final Transition (TP → SP)</strong></p>
|
|
|
3491 |
|
3492 |
<p>Using this method, you can profile the custom CUDA kernel just as we demonstrated earlier with PyTorch's profiler or NVIDIA tools.</p>
|
3493 |
|
3494 |
+
<h3>A2: Typical Scales in LLM Training</h3>
|
3495 |
+
|
3496 |
+
<p>Let's get a feel for the typical sizes of things in LLM training. When we talk about memory or compute, we're often counting "elements" - think of these as numbers in tensors. To get the actual memory in bytes, you'll need to multiply by the size of each number (e.g., 2 bytes for bf16, 4 bytes for fp32).</p>
|
3497 |
+
|
3498 |
+
<p>Here are some quick ballpark figures:</p>
|
3499 |
+
|
3500 |
+
<ul>
|
3501 |
+
<li><strong>Input tokens:</strong> For each batch, we process <d-math>seq \cdot mbs</d-math> tokens, where mbs is the micro batch size and seq is the sequence length.</li>
|
3502 |
+
|
3503 |
+
<li><strong>Activations (hidden states):</strong> For a single layer, the hidden state tensor is of size <d-math>seq \cdot mbs \cdot h</d-math> elements.</li>
|
3504 |
+
|
3505 |
+
<li><strong>Model weights and gradients:</strong> Each weight matrix in your model (like in linears) is about <d-math>h^2</d-math> elements. This is per weight matrix. Gradients have the same size as weights.</li>
|
3506 |
+
|
3507 |
+
<li><strong>Optimizer states:</strong> For each weight matrix (of elements <d-math>h^2</d-math>), if you're using an optimizer like Adam with mixed precision training, it keeps momentum and variance states in fp32 precision (<d-math>2 \cdot h^2</d-math>), plus master weights in fp32 (<d-math>h^2</d-math>). So total optimizer states will be around (<d-math>6 \cdot h^2</d-math>) per weight matrix.</li>
|
3508 |
+
|
3509 |
+
<li><strong>Total model parameters:</strong> For each transformer block:
|
3510 |
+
<ul>
|
3511 |
+
<li>Attention parameters:
|
3512 |
+
<ul>
|
3513 |
+
<li>QKV projections: <d-math>3h^2</d-math> parameters</li>
|
3514 |
+
<li>Output projection: <d-math>h^2</d-math> parameters</li>
|
3515 |
+
</ul>
|
3516 |
+
</li>
|
3517 |
+
<li>MLP parameters with GLU:
|
3518 |
+
<ul>
|
3519 |
+
<li>Gate and up projections: <d-math>8h^2</d-math> parameters (2 matrices of size <d-math>h \times 4h</d-math>)</li>
|
3520 |
+
<li>Down projection: <d-math>4h^2</d-math> parameters (1 matrix of size <d-math>4h \times h</d-math>)</li>
|
3521 |
+
</ul>
|
3522 |
+
</li>
|
3523 |
+
<li>Total per block: <d-math>16h^2</d-math> with GLU MLPs, or <d-math>12h^2</d-math> without GLU</li>
|
3524 |
+
<li>For full model: <d-math>16h^2 \cdot num\_layers</d-math> (with GLU)</li>
|
3525 |
+
<li>Additional parameters:
|
3526 |
+
<ul>
|
3527 |
+
<li>Input embeddings: <d-math>vocab\_size \cdot h</d-math></li>
|
3528 |
+
<li>LM head: <d-math>vocab\_size \cdot h</d-math> (if not tied with input embeddings)</li>
|
3529 |
+
<li>Positional embeddings (if used): <d-math>max\_seq\_len \cdot h</d-math></li>
|
3530 |
+
</ul>
|
3531 |
+
</li>
|
3532 |
+
</ul>
|
3533 |
+
</li>
|
3534 |
+
|
3535 |
+
<li><strong>Forward and backward pass compute (FLOPs):</strong> A very rough estimate for the FLOPs in a forward pass is <d-math>2 \cdot num\_tokens \cdot num\_params</d-math>. And backward pass compute is twice as that: <d-math>4 \cdot num\_tokens \cdot num\_params</d-math>.</li>
|
3536 |
+
</ul>
|
3537 |
+
|
3538 |
+
<h3>A3: Math for Compute/Communication Overlap</h3>
|
3539 |
+
|
3540 |
+
<p>Using the formulas from the previous section, we can estimate when computation and communication can effectively overlap in distributed training. Let's look at data parallelism (Zero-0) as an example.</p>
|
3541 |
+
|
3542 |
+
<h4>Data Parallelism Communication Analysis</h4>
|
3543 |
+
|
3544 |
+
<p>The total gradient size that needs to be communicated is:</p>
|
3545 |
+
<ul>
|
3546 |
+
<li>Gradients = Parameters ≈ <d-math>num\_layers \cdot 16h^2</d-math></li>
|
3547 |
+
</ul>
|
3548 |
+
|
3549 |
+
<p>During backward pass, these gradients are communicated in buckets (default 25MB). The communication time for each bucket is:</p>
|
3550 |
+
|
3551 |
+
<d-math block>
|
3552 |
+
t_{comm} = t_{comm\_bucket} = \frac{bucket\_size \cdot 2(DP-1)}{DP \cdot peak\_bw}
|
3553 |
+
</d-math>
|
3554 |
+
|
3555 |
+
<p>The computation time for backward pass is:</p>
|
3556 |
+
|
3557 |
+
<d-math block>
|
3558 |
+
t_{compute} = \frac{4 \cdot num\_tokens \cdot num\_params}{peak\_flops}
|
3559 |
+
</d-math>
|
3560 |
+
|
3561 |
+
<p>For effective overlap, we need:</p>
|
3562 |
+
|
3563 |
+
<d-math block>
|
3564 |
+
\frac{t_{comm}}{t_{compute}} = \frac{num\_params}{2 \cdot num\_tokens} \cdot \frac{DP-1}{DP} \cdot \frac{peak\_flops}{peak\_bw} \leq 1
|
3565 |
+
</d-math>
|
3566 |
+
|
3567 |
+
<p>This ratio helps determine if communication will become a bottleneck in training. When the ratio is less than 1, communication can be fully overlapped with computation.</p>
|
3568 |
+
|
3569 |
+
<h4>ZeRO-3 (FSDP) Communication Analysis</h4>
|
3570 |
+
|
3571 |
+
<p>For ZeRO-3, parameters and gradients are sharded across GPUs. Let's analyze the communication pattern for a model with transformer blocks of size <d-math>16h^2</d-math> parameters each:</p>
|
3572 |
+
|
3573 |
+
<ul>
|
3574 |
+
<li>For each transformer block in forward pass:
|
3575 |
+
<ul>
|
3576 |
+
<li>Allgather parameters: <d-math>16h^2/DP</d-math> bytes per rank</li>
|
3577 |
+
</ul>
|
3578 |
+
</li>
|
3579 |
+
<li>For each transformer block in backward pass:
|
3580 |
+
<ul>
|
3581 |
+
<li>Allgather parameters: <d-math>16h^2/DP</d-math> bytes per rank</li>
|
3582 |
+
<li>Reducescatter gradients: <d-math>16h^2/DP</d-math> bytes per rank</li>
|
3583 |
+
</ul>
|
3584 |
+
</li>
|
3585 |
+
<li>Total communication per block: <d-math>3 \cdot 16h^2/DP</d-math> bytes</li>
|
3586 |
+
<li>Total communication for full model: <d-math>3 \cdot num\_layers \cdot 16h^2/DP</d-math> bytes</li>
|
3587 |
+
</ul>
|
3588 |
+
|
3589 |
+
<p>The communication time for allgather operations is:</p>
|
3590 |
+
|
3591 |
+
<d-math block>
|
3592 |
+
t_{comm} = 16h^2 \cdot \frac{DP-1}{DP \cdot peak\_bw}
|
3593 |
+
</d-math>
|
3594 |
+
|
3595 |
+
<p>The computation time for forward pass of one decoder layer is:</p>
|
3596 |
+
<d-math block>
|
3597 |
+
t_{compute} = \frac{32 \cdot seq\_len \cdot mbs \cdot h^2}{peak\_flops}
|
3598 |
+
</d-math>
|
3599 |
+
|
3600 |
+
<p>For effective overlap between computation and communication, we need:</p>
|
3601 |
+
|
3602 |
+
<d-math block>
|
3603 |
+
\frac{t_{comm}}{t_{compute}} = \frac{1}{2 \cdot seq\_len \cdot mbs} \cdot \frac{DP-1}{DP} \cdot \frac{peak\_flops}{peak\_bw} \leq 1
|
3604 |
+
</d-math>
|
3605 |
+
|
3606 |
+
<p>When this ratio is less than 1, the communication of parameters for the next layer can be hidden behind the computation of the current layer.</p>
|
3607 |
+
`
|
3608 |
+
<h4>TP Communication Analysis</h4>
|
3609 |
+
|
3610 |
+
<p>For Tensor Parallel (TP), activations are sharded across GPUs during linears. Let's analyze the communication pattern:</p>
|
3611 |
+
|
3612 |
+
<ul>
|
3613 |
+
<li>For each column linear in forward pass:
|
3614 |
+
<ul>
|
3615 |
+
<li>Allgather activations: <d-math>seq \cdot mbs \cdot h/TP</d-math> bytes per rank</li>
|
3616 |
+
</ul>
|
3617 |
+
</li>
|
3618 |
+
<li>For each column linear in backward pass:
|
3619 |
+
<ul>
|
3620 |
+
<li>Reducescatter gradients: <d-math>seq \cdot mbs \cdot h/TP</d-math> bytes per rank</li>
|
3621 |
+
</ul>
|
3622 |
+
</li>
|
3623 |
+
<li>And vice-versa for row linears. Each transformer block has 2 column linears and 2 row linears.</li>
|
3624 |
+
<li>Total communication per block: <d-math>8 \cdot seq \cdot mbs \cdot h/TP</d-math> bytes</li>
|
3625 |
+
<li>Total communication for full model: <d-math>8 \cdot num\_layers \cdot seq \cdot mbs \cdot h/TP</d-math> bytes</li>
|
3626 |
+
</ul>
|
3627 |
+
<p>Let's analyze if we can overlap the allgather communication for one layer with the computation of the next linear. The communication time for allgather operations is:</p>
|
3628 |
+
|
3629 |
+
<d-math block>
|
3630 |
+
t_{comm} = \frac{seq \cdot mbs \cdot h \cdot (TP-1)}{TP \cdot peak\_bw}
|
3631 |
+
</d-math>
|
3632 |
+
|
3633 |
+
<p>While the computation time for the next linear (with parameters <d-math>h^2</d-math>) is:</p>
|
3634 |
+
|
3635 |
+
<d-math block>
|
3636 |
+
t_{compute} = \frac{2 \cdot seq \cdot mbs \cdot h^2}{TP \cdot peak\_flops}
|
3637 |
+
</d-math>
|
3638 |
+
|
3639 |
+
<p>For effective overlap, we want the communication time to be less than the compute time:</p>
|
3640 |
+
<d-math block>
|
3641 |
+
\frac{t_{comm}}{t_{compute}} = \frac{TP-1}{2 \cdot h} \cdot \frac{peak\_flops}{peak\_bw} \leq 1
|
3642 |
+
</d-math>
|
3643 |
+
|
3644 |
+
<p>This ratio tells us whether we can successfully hide the allgather communication behind the computation of the next linear. Interestingly, the ratio only depends on the hidden size h and tensor parallelism degree TP, not on sequence length or batch size.</p>
|
3645 |
|
3646 |
|
3647 |
+
<h4>PP Communication Analysis</h4>
|
3648 |
+
|
3649 |
+
<p>For Pipeline Parallel (PP), activations and gradients are communicated between pipeline stages. Let's analyze the communication pattern:</p>
|
3650 |
+
|
3651 |
+
<ul>
|
3652 |
+
<li>For each microbatch in forward pass:
|
3653 |
+
<ul>
|
3654 |
+
<li>Receive and send activations: <d-math>2 \cdot seq \cdot mbs \cdot h</d-math> bytes</li>
|
3655 |
+
</ul>
|
3656 |
+
</li>
|
3657 |
+
<li>For each microbatch in backward pass:
|
3658 |
+
<ul>
|
3659 |
+
<li>Receive and send gradients: <d-math>2 \cdot seq \cdot mbs \cdot h</d-math> bytes</li>
|
3660 |
+
</ul>
|
3661 |
+
</li>
|
3662 |
+
<li>Total communication per microbatch: <d-math>4 \cdot seq \cdot mbs \cdot h</d-math> bytes</li>
|
3663 |
+
<li>For gradient accumulation steps (gas), total communication: <d-math>4 \cdot gas \cdot seq \cdot mbs \cdot h</d-math> bytes</li>
|
3664 |
+
</ul>
|
3665 |
+
|
3666 |
+
<p>Let's analyze if we can overlap the communication of activations/gradients with computation of the next transformer block. The computation time for transformer blocks in the next pipeline stage is:</p>
|
3667 |
+
|
3668 |
+
<d-math block>
|
3669 |
+
t_{compute} = \frac{32 \cdot seq \cdot mbs \cdot h^2 \cdot num\_layers\_in\_next\_pp}{peak\_flops}
|
3670 |
+
</d-math>
|
3671 |
+
|
3672 |
+
<p>While the communication time for P2P transfer is:</p>
|
3673 |
+
|
3674 |
+
<d-math block>
|
3675 |
+
t_{comm} = \frac{seq \cdot mbs \cdot h}{peak\_bw}
|
3676 |
+
</d-math>
|
3677 |
+
|
3678 |
+
<p>For effective overlap, we want:</p>
|
3679 |
+
|
3680 |
+
<d-math block>
|
3681 |
+
\frac{t_{comm}}{t_{compute}} = \frac{peak\_flops}{32 \cdot h \cdot num\_layers\_in\_next\_pp \cdot peak\_bw} \leq 1
|
3682 |
+
</d-math>
|
3683 |
+
|
3684 |
+
<p>Similar to TP, this ratio is independent of sequence length and batch size. It depends on the hidden size h, number of layers in the next pipeline stage, and the ratio of compute to P2P bandwidth capabilities of the hardware.</p>
|
3685 |
+
|
3686 |
+
|
3687 |
</d-article>
|
3688 |
|
3689 |
<d-appendix>
|
src/index.html
CHANGED
@@ -1168,17 +1168,17 @@
|
|
1168 |
<p style="margin-bottom: 0;"><strong>First Transition (SP → TP)</strong></p>
|
1169 |
<ul style="margin-top: 0;">
|
1170 |
<li>"g" operation (all-gather) combines Y1<em> and Y2</em> back to full sequence length</li>
|
1171 |
-
<li> Restores Y (b,s,h) since column linear
|
1172 |
</ul>
|
1173 |
-
<p style="margin-bottom: 0;"><strong>First Linear
|
1174 |
<ul style="margin-top: 0;">
|
1175 |
-
<li>A1 is a column-linear
|
1176 |
<li>GeLU is applied independently on each GPU</li>
|
1177 |
<li>Z1* is (b,s,h/2)</li>
|
1178 |
</ul>
|
1179 |
-
<p style="margin-bottom: 0;"><strong>Second Linear
|
1180 |
<ul style="margin-top: 0;">
|
1181 |
-
<li>B1 is a row-linear
|
1182 |
<li>W1 is (b,s,h)</li>
|
1183 |
</ul>
|
1184 |
<p style="margin-bottom: 0;"><strong>Final Transition (TP → SP)</strong></p>
|
@@ -3491,9 +3491,199 @@
|
|
3491 |
|
3492 |
<p>Using this method, you can profile the custom CUDA kernel just as we demonstrated earlier with PyTorch's profiler or NVIDIA tools.</p>
|
3493 |
|
3494 |
-
<h3>A2:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3495 |
|
3496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3497 |
</d-article>
|
3498 |
|
3499 |
<d-appendix>
|
|
|
1168 |
<p style="margin-bottom: 0;"><strong>First Transition (SP → TP)</strong></p>
|
1169 |
<ul style="margin-top: 0;">
|
1170 |
<li>"g" operation (all-gather) combines Y1<em> and Y2</em> back to full sequence length</li>
|
1171 |
+
<li> Restores Y (b,s,h) since column linear needs full hidden dimension h</li>
|
1172 |
</ul>
|
1173 |
+
<p style="margin-bottom: 0;"><strong>First Linear (TP Region)</strong></p>
|
1174 |
<ul style="margin-top: 0;">
|
1175 |
+
<li>A1 is a column-linear, so it splits Y along the hidden dimension</li>
|
1176 |
<li>GeLU is applied independently on each GPU</li>
|
1177 |
<li>Z1* is (b,s,h/2)</li>
|
1178 |
</ul>
|
1179 |
+
<p style="margin-bottom: 0;"><strong>Second Linear (TP Region)</strong></p>
|
1180 |
<ul style="margin-top: 0;">
|
1181 |
+
<li>B1 is a row-linear, so it restores the hidden dimension</li>
|
1182 |
<li>W1 is (b,s,h)</li>
|
1183 |
</ul>
|
1184 |
<p style="margin-bottom: 0;"><strong>Final Transition (TP → SP)</strong></p>
|
|
|
3491 |
|
3492 |
<p>Using this method, you can profile the custom CUDA kernel just as we demonstrated earlier with PyTorch's profiler or NVIDIA tools.</p>
|
3493 |
|
3494 |
+
<h3>A2: Typical Scales in LLM Training</h3>
|
3495 |
+
|
3496 |
+
<p>Let's get a feel for the typical sizes of things in LLM training. When we talk about memory or compute, we're often counting "elements" - think of these as numbers in tensors. To get the actual memory in bytes, you'll need to multiply by the size of each number (e.g., 2 bytes for bf16, 4 bytes for fp32).</p>
|
3497 |
+
|
3498 |
+
<p>Here are some quick ballpark figures:</p>
|
3499 |
+
|
3500 |
+
<ul>
|
3501 |
+
<li><strong>Input tokens:</strong> For each batch, we process <d-math>seq \cdot mbs</d-math> tokens, where mbs is the micro batch size and seq is the sequence length.</li>
|
3502 |
+
|
3503 |
+
<li><strong>Activations (hidden states):</strong> For a single layer, the hidden state tensor is of size <d-math>seq \cdot mbs \cdot h</d-math> elements.</li>
|
3504 |
+
|
3505 |
+
<li><strong>Model weights and gradients:</strong> Each weight matrix in your model (like in linears) is about <d-math>h^2</d-math> elements. This is per weight matrix. Gradients have the same size as weights.</li>
|
3506 |
+
|
3507 |
+
<li><strong>Optimizer states:</strong> For each weight matrix (of elements <d-math>h^2</d-math>), if you're using an optimizer like Adam with mixed precision training, it keeps momentum and variance states in fp32 precision (<d-math>2 \cdot h^2</d-math>), plus master weights in fp32 (<d-math>h^2</d-math>). So total optimizer states will be around (<d-math>6 \cdot h^2</d-math>) per weight matrix.</li>
|
3508 |
+
|
3509 |
+
<li><strong>Total model parameters:</strong> For each transformer block:
|
3510 |
+
<ul>
|
3511 |
+
<li>Attention parameters:
|
3512 |
+
<ul>
|
3513 |
+
<li>QKV projections: <d-math>3h^2</d-math> parameters</li>
|
3514 |
+
<li>Output projection: <d-math>h^2</d-math> parameters</li>
|
3515 |
+
</ul>
|
3516 |
+
</li>
|
3517 |
+
<li>MLP parameters with GLU:
|
3518 |
+
<ul>
|
3519 |
+
<li>Gate and up projections: <d-math>8h^2</d-math> parameters (2 matrices of size <d-math>h \times 4h</d-math>)</li>
|
3520 |
+
<li>Down projection: <d-math>4h^2</d-math> parameters (1 matrix of size <d-math>4h \times h</d-math>)</li>
|
3521 |
+
</ul>
|
3522 |
+
</li>
|
3523 |
+
<li>Total per block: <d-math>16h^2</d-math> with GLU MLPs, or <d-math>12h^2</d-math> without GLU</li>
|
3524 |
+
<li>For full model: <d-math>16h^2 \cdot num\_layers</d-math> (with GLU)</li>
|
3525 |
+
<li>Additional parameters:
|
3526 |
+
<ul>
|
3527 |
+
<li>Input embeddings: <d-math>vocab\_size \cdot h</d-math></li>
|
3528 |
+
<li>LM head: <d-math>vocab\_size \cdot h</d-math> (if not tied with input embeddings)</li>
|
3529 |
+
<li>Positional embeddings (if used): <d-math>max\_seq\_len \cdot h</d-math></li>
|
3530 |
+
</ul>
|
3531 |
+
</li>
|
3532 |
+
</ul>
|
3533 |
+
</li>
|
3534 |
+
|
3535 |
+
<li><strong>Forward and backward pass compute (FLOPs):</strong> A very rough estimate for the FLOPs in a forward pass is <d-math>2 \cdot num\_tokens \cdot num\_params</d-math>. And backward pass compute is twice as that: <d-math>4 \cdot num\_tokens \cdot num\_params</d-math>.</li>
|
3536 |
+
</ul>
|
3537 |
+
|
3538 |
+
<h3>A3: Math for Compute/Communication Overlap</h3>
|
3539 |
+
|
3540 |
+
<p>Using the formulas from the previous section, we can estimate when computation and communication can effectively overlap in distributed training. Let's look at data parallelism (Zero-0) as an example.</p>
|
3541 |
+
|
3542 |
+
<h4>Data Parallelism Communication Analysis</h4>
|
3543 |
+
|
3544 |
+
<p>The total gradient size that needs to be communicated is:</p>
|
3545 |
+
<ul>
|
3546 |
+
<li>Gradients = Parameters ≈ <d-math>num\_layers \cdot 16h^2</d-math></li>
|
3547 |
+
</ul>
|
3548 |
+
|
3549 |
+
<p>During backward pass, these gradients are communicated in buckets (default 25MB). The communication time for each bucket is:</p>
|
3550 |
+
|
3551 |
+
<d-math block>
|
3552 |
+
t_{comm} = t_{comm\_bucket} = \frac{bucket\_size \cdot 2(DP-1)}{DP \cdot peak\_bw}
|
3553 |
+
</d-math>
|
3554 |
+
|
3555 |
+
<p>The computation time for backward pass is:</p>
|
3556 |
+
|
3557 |
+
<d-math block>
|
3558 |
+
t_{compute} = \frac{4 \cdot num\_tokens \cdot num\_params}{peak\_flops}
|
3559 |
+
</d-math>
|
3560 |
+
|
3561 |
+
<p>For effective overlap, we need:</p>
|
3562 |
+
|
3563 |
+
<d-math block>
|
3564 |
+
\frac{t_{comm}}{t_{compute}} = \frac{num\_params}{2 \cdot num\_tokens} \cdot \frac{DP-1}{DP} \cdot \frac{peak\_flops}{peak\_bw} \leq 1
|
3565 |
+
</d-math>
|
3566 |
+
|
3567 |
+
<p>This ratio helps determine if communication will become a bottleneck in training. When the ratio is less than 1, communication can be fully overlapped with computation.</p>
|
3568 |
+
|
3569 |
+
<h4>ZeRO-3 (FSDP) Communication Analysis</h4>
|
3570 |
+
|
3571 |
+
<p>For ZeRO-3, parameters and gradients are sharded across GPUs. Let's analyze the communication pattern for a model with transformer blocks of size <d-math>16h^2</d-math> parameters each:</p>
|
3572 |
+
|
3573 |
+
<ul>
|
3574 |
+
<li>For each transformer block in forward pass:
|
3575 |
+
<ul>
|
3576 |
+
<li>Allgather parameters: <d-math>16h^2/DP</d-math> bytes per rank</li>
|
3577 |
+
</ul>
|
3578 |
+
</li>
|
3579 |
+
<li>For each transformer block in backward pass:
|
3580 |
+
<ul>
|
3581 |
+
<li>Allgather parameters: <d-math>16h^2/DP</d-math> bytes per rank</li>
|
3582 |
+
<li>Reducescatter gradients: <d-math>16h^2/DP</d-math> bytes per rank</li>
|
3583 |
+
</ul>
|
3584 |
+
</li>
|
3585 |
+
<li>Total communication per block: <d-math>3 \cdot 16h^2/DP</d-math> bytes</li>
|
3586 |
+
<li>Total communication for full model: <d-math>3 \cdot num\_layers \cdot 16h^2/DP</d-math> bytes</li>
|
3587 |
+
</ul>
|
3588 |
+
|
3589 |
+
<p>The communication time for allgather operations is:</p>
|
3590 |
+
|
3591 |
+
<d-math block>
|
3592 |
+
t_{comm} = 16h^2 \cdot \frac{DP-1}{DP \cdot peak\_bw}
|
3593 |
+
</d-math>
|
3594 |
+
|
3595 |
+
<p>The computation time for forward pass of one decoder layer is:</p>
|
3596 |
+
<d-math block>
|
3597 |
+
t_{compute} = \frac{32 \cdot seq\_len \cdot mbs \cdot h^2}{peak\_flops}
|
3598 |
+
</d-math>
|
3599 |
+
|
3600 |
+
<p>For effective overlap between computation and communication, we need:</p>
|
3601 |
+
|
3602 |
+
<d-math block>
|
3603 |
+
\frac{t_{comm}}{t_{compute}} = \frac{1}{2 \cdot seq\_len \cdot mbs} \cdot \frac{DP-1}{DP} \cdot \frac{peak\_flops}{peak\_bw} \leq 1
|
3604 |
+
</d-math>
|
3605 |
+
|
3606 |
+
<p>When this ratio is less than 1, the communication of parameters for the next layer can be hidden behind the computation of the current layer.</p>
|
3607 |
+
`
|
3608 |
+
<h4>TP Communication Analysis</h4>
|
3609 |
+
|
3610 |
+
<p>For Tensor Parallel (TP), activations are sharded across GPUs during linears. Let's analyze the communication pattern:</p>
|
3611 |
+
|
3612 |
+
<ul>
|
3613 |
+
<li>For each column linear in forward pass:
|
3614 |
+
<ul>
|
3615 |
+
<li>Allgather activations: <d-math>seq \cdot mbs \cdot h/TP</d-math> bytes per rank</li>
|
3616 |
+
</ul>
|
3617 |
+
</li>
|
3618 |
+
<li>For each column linear in backward pass:
|
3619 |
+
<ul>
|
3620 |
+
<li>Reducescatter gradients: <d-math>seq \cdot mbs \cdot h/TP</d-math> bytes per rank</li>
|
3621 |
+
</ul>
|
3622 |
+
</li>
|
3623 |
+
<li>And vice-versa for row linears. Each transformer block has 2 column linears and 2 row linears.</li>
|
3624 |
+
<li>Total communication per block: <d-math>8 \cdot seq \cdot mbs \cdot h/TP</d-math> bytes</li>
|
3625 |
+
<li>Total communication for full model: <d-math>8 \cdot num\_layers \cdot seq \cdot mbs \cdot h/TP</d-math> bytes</li>
|
3626 |
+
</ul>
|
3627 |
+
<p>Let's analyze if we can overlap the allgather communication for one layer with the computation of the next linear. The communication time for allgather operations is:</p>
|
3628 |
+
|
3629 |
+
<d-math block>
|
3630 |
+
t_{comm} = \frac{seq \cdot mbs \cdot h \cdot (TP-1)}{TP \cdot peak\_bw}
|
3631 |
+
</d-math>
|
3632 |
+
|
3633 |
+
<p>While the computation time for the next linear (with parameters <d-math>h^2</d-math>) is:</p>
|
3634 |
+
|
3635 |
+
<d-math block>
|
3636 |
+
t_{compute} = \frac{2 \cdot seq \cdot mbs \cdot h^2}{TP \cdot peak\_flops}
|
3637 |
+
</d-math>
|
3638 |
+
|
3639 |
+
<p>For effective overlap, we want the communication time to be less than the compute time:</p>
|
3640 |
+
<d-math block>
|
3641 |
+
\frac{t_{comm}}{t_{compute}} = \frac{TP-1}{2 \cdot h} \cdot \frac{peak\_flops}{peak\_bw} \leq 1
|
3642 |
+
</d-math>
|
3643 |
+
|
3644 |
+
<p>This ratio tells us whether we can successfully hide the allgather communication behind the computation of the next linear. Interestingly, the ratio only depends on the hidden size h and tensor parallelism degree TP, not on sequence length or batch size.</p>
|
3645 |
|
3646 |
|
3647 |
+
<h4>PP Communication Analysis</h4>
|
3648 |
+
|
3649 |
+
<p>For Pipeline Parallel (PP), activations and gradients are communicated between pipeline stages. Let's analyze the communication pattern:</p>
|
3650 |
+
|
3651 |
+
<ul>
|
3652 |
+
<li>For each microbatch in forward pass:
|
3653 |
+
<ul>
|
3654 |
+
<li>Receive and send activations: <d-math>2 \cdot seq \cdot mbs \cdot h</d-math> bytes</li>
|
3655 |
+
</ul>
|
3656 |
+
</li>
|
3657 |
+
<li>For each microbatch in backward pass:
|
3658 |
+
<ul>
|
3659 |
+
<li>Receive and send gradients: <d-math>2 \cdot seq \cdot mbs \cdot h</d-math> bytes</li>
|
3660 |
+
</ul>
|
3661 |
+
</li>
|
3662 |
+
<li>Total communication per microbatch: <d-math>4 \cdot seq \cdot mbs \cdot h</d-math> bytes</li>
|
3663 |
+
<li>For gradient accumulation steps (gas), total communication: <d-math>4 \cdot gas \cdot seq \cdot mbs \cdot h</d-math> bytes</li>
|
3664 |
+
</ul>
|
3665 |
+
|
3666 |
+
<p>Let's analyze if we can overlap the communication of activations/gradients with computation of the next transformer block. The computation time for transformer blocks in the next pipeline stage is:</p>
|
3667 |
+
|
3668 |
+
<d-math block>
|
3669 |
+
t_{compute} = \frac{32 \cdot seq \cdot mbs \cdot h^2 \cdot num\_layers\_in\_next\_pp}{peak\_flops}
|
3670 |
+
</d-math>
|
3671 |
+
|
3672 |
+
<p>While the communication time for P2P transfer is:</p>
|
3673 |
+
|
3674 |
+
<d-math block>
|
3675 |
+
t_{comm} = \frac{seq \cdot mbs \cdot h}{peak\_bw}
|
3676 |
+
</d-math>
|
3677 |
+
|
3678 |
+
<p>For effective overlap, we want:</p>
|
3679 |
+
|
3680 |
+
<d-math block>
|
3681 |
+
\frac{t_{comm}}{t_{compute}} = \frac{peak\_flops}{32 \cdot h \cdot num\_layers\_in\_next\_pp \cdot peak\_bw} \leq 1
|
3682 |
+
</d-math>
|
3683 |
+
|
3684 |
+
<p>Similar to TP, this ratio is independent of sequence length and batch size. It depends on the hidden size h, number of layers in the next pipeline stage, and the ratio of compute to P2P bandwidth capabilities of the hardware.</p>
|
3685 |
+
|
3686 |
+
|
3687 |
</d-article>
|
3688 |
|
3689 |
<d-appendix>
|