---
title: Multipack (Sample Packing)
description: Multipack is a technique to pack multiple sequences into a single batch to increase training throughput.
---
## Visualization of Multipack with Flash Attention

Because Flash Attention simply drops the attention mask, we do not need to
construct a 4d attention mask. We only need to concatenate the sequences into
a single batch and let Flash Attention know where each new sequence begins.

4k context, bsz = 4,
each character represents 256 tokens,
X represents a padding token
```
   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
 [ B B B B B B ]
 [ C C C C C C C ]
 [ D D D D ]]
[[ E E E E E E E E ]
 [ F F F F ]
 [ G G G ]
 [ H H H H ]]
[[ I I I ]
 [ J J J ]
 [ K K K K K ]
 [ L L L ]]
```
after padding to the longest input in each step
```
   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
 [ B B B B B B X X X X X ]
 [ C C C C C C C X X X X ]
 [ D D D D X X X X X X X ]]
[[ E E E E E E E E ]
 [ F F F F X X X X ]
 [ G G G X X X X X ]
 [ H H H H X X X X ]]
[[ I I I X X ]
 [ J J J X X ]
 [ K K K K K ]
 [ L L L X X ]]
```
with packing (note it's the same effective number of tokens per step, but a true bsz of 1)
```
   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A B B B B B
   B C C C C C C C D D D D E E E E
   E E E E F F F F F G G G H H H H
   I I I J J J J K K K K K L L L X ]]
```
cu_seqlens:
[[ 0, 11, 17, 24, 28, 36, 41, 44, 48, 51, 55, 60, 64]]
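
In practice, `cu_seqlens` is just the cumulative sum of the packed sequence lengths with a leading zero. Below is a minimal sketch of that idea; the helper `build_cu_seqlens` is hypothetical rather than Axolotl's actual code path, and the commented-out `flash_attn_varlen_func` call shows how the `flash-attn` package typically consumes these boundaries:

```python
import torch

def build_cu_seqlens(seq_lens):
    """Cumulative sequence boundaries: [0, len_0, len_0 + len_1, ...]."""
    cu = [0]
    for n in seq_lens:
        cu.append(cu[-1] + n)
    return torch.tensor(cu, dtype=torch.int32)

# Sequence lengths that reproduce the cu_seqlens above (1 unit = 256 tokens).
seq_lens = [11, 6, 7, 4, 8, 5, 3, 4, 3, 4, 5, 4]
cu_seqlens = build_cu_seqlens(seq_lens)
# tensor([ 0, 11, 17, 24, 28, 36, 41, 44, 48, 51, 55, 60, 64], dtype=torch.int32)

# The packed q/k/v of shape (total_tokens, n_heads, head_dim) would then be
# handed to FlashAttention's variable-length kernel along these boundaries,
# roughly:
#   from flash_attn import flash_attn_varlen_func
#   out = flash_attn_varlen_func(
#       q, k, v,
#       cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
#       max_seqlen_q=max(seq_lens), max_seqlen_k=max(seq_lens),
#       causal=True,
#   )
```

Because attention is computed independently within each span delimited by `cu_seqlens`, tokens from one packed sample never attend to tokens from another, and no attention mask needs to be materialized.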

## Multipack without Flash Attention

Multipack can still be achieved without Flash Attention, but with lower packing
efficiency, since without Flash Attention we cannot join multiple batches into a
single batch due to context length limits. We can use either PyTorch's Scaled
Dot Product Attention implementation or the native PyTorch attention implementation
along with [4d attention masks](https://github.com/huggingface/transformers/pull/27539)
to pack sequences together and avoid cross-attention between the packed sequences.

<img src="./images/4d-mask.png" alt="axolotl" width="800">
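
As a rough sketch of this idea (illustrative only, not Axolotl's implementation; the helper `packed_attention_mask` is hypothetical), a block-diagonal causal mask can be built from per-token sample ids and passed to PyTorch's `scaled_dot_product_attention`:

```python
import torch
import torch.nn.functional as F

def packed_attention_mask(sample_ids: torch.Tensor) -> torch.Tensor:
    """Boolean (batch, 1, seq, seq) mask allowing causal attention only
    within the same packed sample (block-diagonal along the sequence)."""
    # sample_ids: (batch, seq_len), e.g. [[0, 0, 0, 1, 1, 2, 2, 2]], marking
    # which packed sample each token belongs to
    same_sample = sample_ids.unsqueeze(-1) == sample_ids.unsqueeze(-2)  # (b, s, s)
    causal = torch.tril(torch.ones_like(same_sample))                   # lower triangle
    return (same_sample & causal).unsqueeze(1)                          # (b, 1, s, s)

# toy example: three samples of lengths 3, 2, 3 packed into one 8-token row
sample_ids = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 2]])
mask = packed_attention_mask(sample_ids)

q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Each True block along the diagonal corresponds to one packed sample, which is what the 4d-mask figure above depicts; padding tokens can be assigned a sentinel sample id so that real tokens never attend to them.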