# Advanced Insights: Attention Masks with KV-Caching
## Key Pitfalls in Complex Attention Implementations
### Dimension Evolution with Caching
```python
# Crucial dimension transitions in cached attention (compressed-cache / MLA-style):
#   input            compressed cache       decompressed K/V        per-head attention output
[b, s, d_model] -> [b, s+cache, d_c] -> [b, s+cache, d_model] -> [b, num_h, s, d_head]
```
The non-obvious trap: even as the K/V cache grows, the attention output's sequence dimension must match the query length, not the cached length.
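A minimal sketch of this shape behavior, assuming a PyTorch-style `[b, num_h, s, d_head]` layout; the tensor names and sizes below are illustrative and not tied to any particular implementation:
```python
import torch

b, num_h, d_head = 2, 4, 16
s_cached, s_new = 10, 3                              # tokens already cached vs. new this step

q = torch.randn(b, num_h, s_new, d_head)             # queries cover only the new tokens
k = torch.randn(b, num_h, s_cached + s_new, d_head)  # keys/values include the cache
v = torch.randn(b, num_h, s_cached + s_new, d_head)

scores = q @ k.transpose(-2, -1) / d_head**0.5       # [b, num_h, s_new, s_cached + s_new]
att_output = torch.softmax(scores, dim=-1) @ v       # [b, num_h, s_new, d_head]

# The softmax reduces over the cached axis, so the output keeps the query length.
assert att_output.shape == (b, num_h, s_new, d_head)
```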
### Mask Causality with Growing Cache
Standard causal masks break with KV-caching: a square mask built from the new tokens alone ignores the cached prefix and its position-dependent attention patterns. Critical edge cases (see the sketch after this list):
- The new token at local position `i` (absolute position `start_pos + i`) must attend to positions `0 … start_pos + i`, i.e. the whole cache plus earlier new tokens and itself
- Naively extending the standard square mask breaks causality
- Generating the mask per position has a performance cost
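One way to build such a mask, as a sketch: assume an additive float mask (0 = visible, `-inf` = blocked) and that `start_pos` equals the number of tokens already in the cache; the function name is illustrative.
```python
import torch

def cached_causal_mask(s_new: int, start_pos: int, device=None) -> torch.Tensor:
    """Additive mask of shape [s_new, start_pos + s_new].

    Row i (the new token at absolute position start_pos + i) may attend to
    everything in the cache plus new tokens up to and including itself.
    """
    # The cached prefix is fully visible to every new token (all zeros)...
    cache_block = torch.zeros(s_new, start_pos, device=device)
    # ...while the new-token block is strictly upper-triangular -inf, enforcing causality.
    new_block = torch.triu(
        torch.full((s_new, s_new), float("-inf"), device=device), diagonal=1
    )
    return torch.cat([cache_block, new_block], dim=-1)

print(cached_causal_mask(s_new=3, start_pos=4))
# tensor([[0., 0., 0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., 0.,   0., -inf],
#         [0., 0., 0., 0., 0.,   0.,   0.]])
```
The resulting `[s_new, start_pos + s_new]` mask broadcasts over `[b, num_h, s_new, start_pos + s_new]` score tensors, which leads directly into the batch-handling point below.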
### Optimization Considerations
1. Memory vs. compute tradeoff: precomputing extended masks vs. generating them per position
2. Batch dimension handling: whether the mask is broadcast or materialized per batch/head changes memory usage (see the sketch after this list)
3. Fused attention kernels may break with custom mask handling
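To make point 2 concrete, a rough comparison under assumed sizes (all numbers illustrative): a broadcastable mask costs a single copy, while materializing it per batch and head multiplies that cost by `b * num_h`.
```python
import torch

b, num_h, s_new, total = 8, 32, 128, 4096        # illustrative sizes

shared = torch.zeros(1, 1, s_new, total)         # broadcastable: one copy for all batches/heads
expanded = shared.expand(b, num_h, s_new, total).contiguous()  # materialized per batch/head

mib = lambda t: t.element_size() * t.numel() / 2**20
print(f"broadcast mask: {mib(shared):.0f} MiB")   # 2 MiB
print(f"expanded mask:  {mib(expanded):.0f} MiB") # 512 MiB
```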
## Debugging Strategy for Non-Obvious Cases
Monitor these dimension transitions for subtle bugs:
```python
C_KV.shape        # Should grow: [b, s₁, d_c] -> [b, s₁+s₂, d_c]
K_state.shape     # Post-projection growth affects attention patterns
att_output.shape  # Must maintain query dimensions despite K/V growth
```
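These checks can be made executable with a few assertions in the decode loop; a minimal sketch, where the sizes and the `C_KV` / `att_output` stand-ins are illustrative:
```python
import torch

b, d_c, num_h, d_head = 2, 64, 4, 16
C_KV = torch.zeros(b, 0, d_c)                    # empty compressed cache

chunks = [5, 1, 1]                               # prefill, then two decode steps
for step, s_new in enumerate(chunks):
    c_new = torch.randn(b, s_new, d_c)           # compressed K/V for the new tokens
    C_KV = torch.cat([C_KV, c_new], dim=1)       # cache grows along the sequence axis

    att_output = torch.randn(b, num_h, s_new, d_head)  # stand-in for the real attention result

    assert C_KV.shape == (b, sum(chunks[: step + 1]), d_c)  # cache length grows as expected
    assert att_output.shape[2] == s_new                     # output keeps the query length
```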
## Practical Example: DeepSeek's MLA Edge Case
In Multi-head Latent Attention (MLA), the compressed KV cache introduces subtle interactions with attention masks due to (see the simplified sketch after this list):
1. Joint K/V compression affecting position-dependent patterns
2. Non-standard dimension flow through compression and decompression
3. Mask causality preservation across cached compressed states
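A simplified sketch of how these pieces interact, not DeepSeek's actual implementation: it omits RoPE and the decoupled key path, and every weight name and dimension below is an illustrative stand-in. The point is that the cache grows in the latent `d_c` space, while the mask must still cover cache plus new tokens after up-projection.
```python
import torch

b, d_model, d_c, num_h, d_head = 2, 256, 64, 4, 32

# Illustrative projection matrices (random stand-ins for learned weights).
W_DKV = torch.randn(d_model, d_c) / d_model**0.5            # joint K/V down-projection
W_UK  = torch.randn(d_c, num_h * d_head) / d_c**0.5         # up-projection to keys
W_UV  = torch.randn(d_c, num_h * d_head) / d_c**0.5         # up-projection to values
W_Q   = torch.randn(d_model, num_h * d_head) / d_model**0.5

C_KV = torch.randn(b, 10, d_c)                              # latent cache from earlier steps
x = torch.randn(b, 3, d_model)                              # three new tokens
start_pos, s_new = C_KV.shape[1], x.shape[1]

C_KV = torch.cat([C_KV, x @ W_DKV], dim=1)                  # cache grows in compressed d_c space
total = C_KV.shape[1]

q = (x @ W_Q).view(b, s_new, num_h, d_head).transpose(1, 2)      # [b, num_h, s_new, d_head]
k = (C_KV @ W_UK).view(b, total, num_h, d_head).transpose(1, 2)  # decompressed keys
v = (C_KV @ W_UV).view(b, total, num_h, d_head).transpose(1, 2)  # decompressed values

# Cache-aware causal mask: zeros over the cached prefix, upper-triangular -inf over new tokens.
mask = torch.cat([torch.zeros(s_new, start_pos),
                  torch.triu(torch.full((s_new, s_new), float("-inf")), diagonal=1)], dim=-1)

scores = q @ k.transpose(-2, -1) / d_head**0.5 + mask       # mask broadcasts over batch and heads
att_output = torch.softmax(scores, dim=-1) @ v              # [b, num_h, s_new, d_head]
assert att_output.shape == (b, num_h, s_new, d_head)
```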