KV Caching Explained: Optimizing Transformer Inference Efficiency
Introduction
When AI models generate text, they often repeat many of the same calculations, which can slow things down. Key-Value caching is a technique that helps speed up this process by remembering important information from previous steps. Instead of recomputing everything from scratch, the model reuses what it has already calculated, making text generation much faster and more efficient.
In this blog post, we'll break down KV caching in an easy-to-understand way, explain why it's useful, and show how it helps AI models work faster.
Prerequisites
To fully grasp the content, readers should be familiar with:
- Transformer Architecture: Familiarity with components such as the attention mechanism.
- Autoregressive Modeling: Understanding of how models like GPT generate sequences.
- Linear Algebra Basics: Concepts like matrix multiplication and transposition, which are essential for understanding attention computations.
This blog should cover most of the prerequisites needed for this article.
Here are some of the most essential takeaways (a short attention sketch follows the list):
- The attention-weight matrix has a shape of (sequence length × sequence length).
- Masked multi-head attention lets each token's representation be built from itself and all the previous tokens.
- To generate a new token, the model needs all the previous tokens and their representations, each of which was in turn computed from its own preceding tokens.
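To make these takeaways concrete, here is a minimal sketch of masked (causal) self-attention in plain PyTorch. The tensor sizes, random input, and random projection matrices are arbitrary stand-ins for illustration only, not values from a real model:

import torch
import torch.nn.functional as F

seq_len, d_model = 5, 16
x = torch.randn(1, seq_len, d_model)                       # (batch, seq_len, d_model)

# Random matrices standing in for the learned query/key/value projections of one head
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))
q, k, v = x @ W_q, x @ W_k, x @ W_v

scores = q @ k.transpose(-2, -1) / d_model ** 0.5          # (1, seq_len, seq_len)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float("-inf"))    # hide future tokens
weights = F.softmax(scores, dim=-1)                        # attention weights

print(weights.shape)        # torch.Size([1, 5, 5]) -> (seq_len, seq_len) per sequence
output = weights @ v        # each token's output mixes itself and its predecessors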
Standard Inference and the Rise of KV Caching
When a model generates text, it looks at all the previous tokens to predict the next one. Normally, it would repeat the same calculations for every new token, which can slow things down.
KV caching removes this redundant work by remembering those calculations from previous steps: the intermediate key and value states of the attention layers are stored during inference and reused for later tokens.
How Does KV Caching Work?
Step-by-Step Process
- First Generation: When the model sees the first input, it calculates and stores its keys and values in the cache.
- Next Words: For each new word, the model retrieves the stored keys and values and adds the new ones instead of starting over.
- Efficient Attention Computation: attention is computed using the cached keys and values together with the new query to produce the output (a sketch of this step follows the illustration below).
- Update Input: the newly generated token is appended to the input, and the process repeats until generation is finished.
The process is illustrated below:
Token 1: [K1, V1] → Cache: [K1, V1]
Token 2: [K2, V2] → Cache: [K1, K2], [V1, V2]
...
Token n: [Kn, Vn] → Cache: [K1, K2, ..., Kn], [V1, V2, ..., Vn]
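The sketch below walks through a single decoding step with cached keys and values, as referenced in the attention step above. Random tensors stand in for a real model's per-token projections; the sizes are arbitrary:

import torch
import torch.nn.functional as F

d_head = 16
K_cache = torch.randn(1, 4, d_head)   # keys of the 4 tokens processed so far
V_cache = torch.randn(1, 4, d_head)   # values of the 4 tokens processed so far

# A new token arrives: only its own projections are computed...
q_new = torch.randn(1, 1, d_head)
k_new = torch.randn(1, 1, d_head)
v_new = torch.randn(1, 1, d_head)

# ...and its key/value are appended to the cache instead of recomputing the prefix.
K_cache = torch.cat([K_cache, k_new], dim=1)   # (1, 5, d_head)
V_cache = torch.cat([V_cache, v_new], dim=1)   # (1, 5, d_head)

# Attention for the new token only: a single (1, 5) weight row, not a full (5, 5) matrix.
scores = q_new @ K_cache.transpose(-2, -1) / d_head ** 0.5
weights = F.softmax(scores, dim=-1)
out = weights @ V_cache                        # (1, 1, d_head)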
(Visual comparison: KV Caching vs. Standard Inference. The visual uses a small sequence length for readability; in practice it can be significantly larger.)
Comparison: KV Caching vs. Standard Inference
Here's how KV caching compares to standard inference:
| Feature | Standard Inference | KV Caching |
|---|---|---|
| Computation per Token | Repeats the same calculations for every new token. | Reuses past calculations and computes only what is new. |
| Memory Usage | Keeps no persistent cache, so little extra memory is needed. | Stores past keys and values, so memory grows with sequence length. |
| Speed | Gets slower as the text gets longer because work is repeated. | Stays fast even for long texts by avoiding repeated work. |
| Efficiency | High computational cost and slower response times. | Faster and more efficient because past work is reused. |
| Handling Long Texts | Struggles with long texts due to repeated calculations. | Well suited to long texts since past steps are remembered. |
KV caching makes a big difference in speed and efficiency, especially for long texts. By saving and reusing past calculations, it avoids the need to start over each time, making it much faster than the regular way of generating text.
Practical Implementation
This is a simplified example of implementing KV caching in PyTorch:
# Simplified example of a KV cache in PyTorch
import torch

class KVCache:
    def __init__(self):
        self.cache = {"key": None, "value": None}

    def update(self, key, value):
        if self.cache["key"] is None:
            # First step: initialize the cache with the prompt's keys and values
            self.cache["key"] = key
            self.cache["value"] = value
        else:
            # Later steps: append the new token's key/value along the sequence dimension
            self.cache["key"] = torch.cat([self.cache["key"], key], dim=1)
            self.cache["value"] = torch.cat([self.cache["value"], value], dim=1)

    def get_cache(self):
        return self.cache
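As a quick, hypothetical usage example, the class above could be driven like this in a toy decoding loop, with random tensors standing in for the per-token keys and values a real model would produce:

import torch

cache = KVCache()
d_head = 16
for step in range(3):
    # In a real model these would come from projecting the newly generated token
    k_new = torch.randn(1, 1, d_head)
    v_new = torch.randn(1, 1, d_head)
    cache.update(k_new, v_new)

kv = cache.get_cache()
print(kv["key"].shape, kv["value"].shape)   # torch.Size([1, 3, 16]) for both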
When using the transformers library, this behavior is enabled by default through the use_cache parameter. You can also choose between several caching strategies through the cache_implementation parameter. Here is a minimal example:
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceTB/SmolLM2-1.7B')
model = AutoModelForCausalLM.from_pretrained('HuggingFaceTB/SmolLM2-1.7B').cuda()
tokens = tokenizer.encode("The red cat was", return_tensors="pt").cuda()
output = model.generate(
    tokens, max_new_tokens=300, use_cache=True  # use_cache defaults to True
)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
We benchmarked the code above with and without KV caching on a T4 GPU and got the following results:
| With KV Caching | Standard Inference | Speedup |
|---|---|---|
| 11.7 s | 1 min 1 s | ~5.21x faster |
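Exact timings will vary with hardware and generation settings. As a rough sketch, the comparison can be reproduced by timing the same call with the cache disabled, reusing model and tokens from the snippet above:

import time

start = time.time()
model.generate(tokens, max_new_tokens=300, use_cache=False)  # recomputes keys/values at every step
print(f"without cache: {time.time() - start:.1f}s")

start = time.time()
model.generate(tokens, max_new_tokens=300, use_cache=True)   # reuses cached keys/values
print(f"with cache: {time.time() - start:.1f}s")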
Conclusion
KV caching is a simple but powerful technique that helps AI models generate text faster and more efficiently. By remembering past calculations instead of repeating them, it reduces the time and compute needed to predict new tokens. While it does require extra memory, the technique is especially useful for long conversations, ensuring fast and efficient generation.
Understanding KV caching can help developers and AI enthusiasts build faster, smarter, and more scalable language models for real-world applications.
I would like to extend my sincerest gratitude to Aritra Roy Gosthipaty for his invaluable support, feedback, and dedication in developing this blog post.