DEV Community

Jimin Lee


Understanding the KV Cache (feat. Self-Attention)

If you’ve been following the world of LLMs, you’ve probably heard of the KV Cache. It’s a technique that’s mentioned constantly, which tells you it’s both widely used and incredibly important. Today, we're going to break down what the KV Cache is and why it’s such a big deal.


LLMs are Slow

Let's start with a simple fact: LLMs are slow. You might think, "Well, they're huge, so of course they're slow." While their massive size is the primary reason, it's not the whole story.

There are two other major culprits:

  1. The Self-Attention mechanism.
  2. The Auto-Regressive generation method.

The KV Cache is a clever solution designed to tackle the performance bottleneck created by this combination. To understand the solution, we first need to understand the problems.


Grasping Context: Self-Attention

At the heart of why Transformers work so well is Self-Attention. Just like the name suggests, it’s about paying attention to oneself—in this case, "self" refers to the input sequence.

Let's take an example. Consider the sentence: "I went to the bank to deposit money." The word "bank" can mean a financial institution or a riverbank. Yet, we instantly understand it means a place for money. How do we do that?

If someone just gave you the word "bank" in isolation, you couldn't be sure of its meaning. You might guess it's a financial institution since that's a common usage, but it would just be a guess.

However, in the full sentence, the other words provide context. The core idea of Self-Attention is to mimic this process. Imagine a Transformer processing the word "bank." Just like us, it can't be certain of the meaning from the word alone. So, what does it do? It looks at the other words in the sentence to figure out their relationships.

  • "I" isn't highly related to "bank."
  • "went" has some relevance to "bank."
  • "money" is highly related to "bank."

By assessing the relevance of every word to "bank" (including itself), the Transformer gains a much richer understanding of the input. Since "money" is a highly relevant word, the model correctly infers that "bank" is more likely to be a financial institution than a riverbank.

How is Self-Attention Actually Calculated?

So far, we've talked about "relevance" in abstract terms. How does a machine actually calculate it? To fully grasp this, you'd need to look at the mathematical formulas, but we'll skip the heavy math today and stick to the core concepts.

For each word (or, more accurately, token) in the input, the model calculates three distinct values: a Query (Q), a Key (K), and a Value (V). For our example sentence, "I went to the bank to deposit money," the model would compute Q_I, K_I, V_I, Q_went, K_went, V_went, ..., and so on for every word. For an 8-word sentence, that's a total of 8 * 3 = 24 vectors.

Here’s what these three roles do in Self-Attention:

  • Query (Q): This is the current word's "question." It's like the word shouting, "Hey everyone, who has information relevant to me?"
  • Key (K): This is every other word's "nametag" or "keyword." It responds to the Query's question, saying, "Hmm, I might be relevant to you! Check out my nametag."
  • Value (V): This is the word's actual "meaning" or "substance." If a Key is a good match for a Query, the model then says, "Okay, here’s my actual information for you," and passes along its Value.

Analogy: Finding a Book in a Library

Still a bit fuzzy? Let’s try an analogy. Imagine you're in a library looking for a book about the "History of AI," but you don't know the exact title. You head to the computer science section and start browsing the shelves. Reading every single book would take forever, so you start by looking at the titles.

"Operating System Theory" and "Introduction to Compilers" seem less relevant. "Pattern Recognition" seems more relevant, and "Introduction to AI" looks like a great match.

But as the saying goes, don't judge a book by its cover. You pull out a few promising titles and skim their contents:

  • "Operating System Theory": All about how to build OSes.
  • "Introduction to Compilers": Mostly about compiler theory, but has a small section on rule-based AI that could be applicable.
  • "Pattern Recognition": Covers a specific subfield of AI.
  • "Introduction to AI": Describes everything from early AI techniques to modern ones.

Just as you suspected, "Operating System Theory" isn't very relevant, while "Introduction to AI" is highly relevant. "Introduction to Compilers," which seemed irrelevant at first glance, turned out to have some small connection.

Let’s connect this back to Q, K, and V.

  • Your search topic, "History of AI," is the Query (Q). Your goal is to find books that are relevant to this query.
  • The book titles—"Operating System Theory," "Introduction to Compilers," etc.—are the Keys (K). They are keywords that summarize the content. You made an initial judgment by comparing your Query to these Keys.
  • The actual content of each book is the Value (V).

Now, how does that initial comparison shape what you finally take away? Comparing your query ("History of AI") to the titles (K) tells you which books deserve your attention, but a title, no matter how well-written, can't capture everything inside. What you actually learn comes from the content (V) of the books you judged relevant.

In other words, the process is:

  1. First, determine the initial relevance between your query (Q) and each book's title (K).
  2. Based on that relevance score, decide how much of the book's actual content (V) to factor into your final decision.

The higher the relevance from step 1, the more the book's content influences the final result.
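The article doesn't name the function that turns "relevance" into "how much to factor in." Standard Transformers use a softmax, which converts raw scores into positive weights that sum to 1. Here is a minimal sketch, assuming made-up relevance scores for the four books:

```python
import math

def softmax(scores):
    """Turn raw relevance scores into weights that are positive and sum to 1."""
    m = max(scores)                          # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical relevance of each title (Key) to "History of AI" (Query)
books = ["Operating System Theory", "Introduction to Compilers",
         "Pattern Recognition", "Introduction to AI"]
scores = [0.1, 0.8, 1.5, 3.0]

weights = softmax(scores)
for title, w in zip(books, weights):
    print(f"{title:28s} {w:.2f}")
```

"Introduction to AI" ends up with most of the weight. "Introduction to Compilers" still gets a small, non-zero share, which matches the small connection the analogy found.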

Back to the Transformer

Let's return to our original example: "I went to the bank to deposit money."

We want to calculate the attention for the first word, "I." So, "I" becomes our Query. More precisely, the model doesn't use the word "I" directly; it computes a Q vector that represents "I." This is similar to how when you search for "History of AI," your brain interprets the meaning of that phrase rather than just the literal string of characters.

Next, the model calculates the K and V vectors for all words in the sentence. In our library analogy, the books' titles (K) and contents (V) were already fixed. But here, since the model doesn't know what input it will get, it computes K and V dynamically. (You can think of the library analogy as the book authors having pre-calculated the K and V for you).

Now we have the Q vector for "I" and the K and V vectors for every word. The process looks like this:

  1. Compare Q_I with K_I. As it's the same word, the relevance is likely high.
  2. Compare Q_I with K_went. As a subject-verb pair, there’s some relationship.
  3. Compare Q_I with K_to. Semantically and grammatically, the relevance doesn't seem very high.
  4. This continues all the way to K_money.

This completes the step of finding the relevance between the Query and each Key. Based on these relevance scores, the model combines the Value vectors of each word to determine the final attention output.

  • The relationship between "I" and "I" is calculated using (Q_I, K_I, V_I).
  • The relationship between "I" and "went" is calculated using (Q_I, K_went, V_went).
  • The relationship between "I" and "to" is calculated using (Q_I, K_to, V_to).
  • ...and so on.
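The three steps for a single Query (score each Key, normalize the scores, blend the Values) can be sketched in a few lines. This is a minimal sketch. It assumes the scaled dot product and softmax from the original Transformer design, and uses made-up 2-dimensional vectors for "I," "went," and "to":

```python
import math

def attend(query, keys, values):
    """Attention output for ONE query: score every key, softmax, mix the values."""
    d = len(query)
    # Relevance of the query to each key (dot product scaled by sqrt(d))
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    # Softmax: turn scores into weights that are positive and sum to 1
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the value vectors, dimension by dimension
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
    return output, weights

# Made-up 2-dimensional vectors for the first three tokens
Q_I = [1.0, 0.0]
K = [[1.0, 0.2], [0.6, 0.5], [0.1, 0.9]]   # K_I, K_went, K_to
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]   # V_I, V_went, V_to

output, weights = attend(Q_I, K, V)
for token, w in zip(["I", "went", "to"], weights):
    print(f"{token:5s} weight {w:.2f}")
```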

Once this is done, the Self-Attention calculation for the first word, "I," is complete. The next step is to repeat this entire process for "went," "to," "the," "bank," and every other word. For "bank," the calculations would look like:

  • (Q_bank, K_I, V_I): The relevance between "bank" and "I."
  • (Q_bank, K_went, V_went): The relevance between "bank" and "went."
  • ...
  • (Q_bank, K_money, V_money): The relevance between "bank" and "money."

The relevance scores between each pair of words are called attention weights. The weighted combination of Values they produce for each word is that word's Self-Attention output. Calculating Self-Attention therefore means calculating the relationship between every pair of words in the input sequence. Because these values represent inter-word relationships, they capture the context of the sentence, which lets the model know that the "bank" in "I went to the bank to deposit money" is not a riverbank.
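Repeating the single-Query process for every token produces one row of attention weights per word, which gives an N x N table of relationships. Here is a minimal sketch, assuming toy 2-dimensional Q, K, and V vectors for a 4-token sequence (the numbers are made up):

```python
import math

def self_attention(Q, K, V):
    """Self-attention for a whole sequence.

    Q, K, V: lists of N vectors (one per token). Returns the N output vectors
    and the N x N matrix of attention weights (row i = how token i attends).
    """
    d = len(K[0])
    weights = []
    for q in Q:                                    # one row per query token
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    outputs = [
        [sum(w * v[j] for w, v in zip(row, V)) for j in range(len(V[0]))]
        for row in weights
    ]
    return outputs, weights

Q = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, 1.0]]
K = [[0.9, 0.1], [0.2, 0.8], [0.4, 0.4], [0.7, 0.7]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [0.2, 0.8]]
outputs, weights = self_attention(Q, K, V)
print(len(weights), "x", len(weights[0]))   # 4 x 4: one weight per pair of tokens
```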

What Does Calculating Q, K, V Actually Mean?

This part gets a little more mathematical. Feel free to skip to the next section if you'd like.

Like other machine learning algorithms, Transformers only understand numbers. They can't process words like "went" or "bank." So, the first step is to convert each word into a list of numbers, also known as a vector. This is called an embedding.

  • I: [213, 92, 42, 89, 0, 21]
  • went: [5, 2, 43, 99, 92, 111]

Calculating Q, K, and V is the process of transforming a word's embedding into three new vectors. For example, let's say the embedding for "I" is E_I:

  • E_I: [213, 92, 42, 89, 0, 21]
  • Q_I: [234, 1, 53, 111, 5, 4]
  • K_I: [4, 23, 62, 13, 34, 93]
  • V_I: [24, 12, 32, 51, 112, 34]

These numbers are just for illustration. In reality, the model multiplies the input embedding by three different matrices called W_Q, W_K, and W_V to get the final Q, K, and V vectors.

  • E_I x W_Q = Q_I
  • E_I x W_K = K_I
  • E_I x W_V = V_I
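The three multiplications above are ordinary vector-times-matrix products. Here is a minimal sketch, assuming a hypothetical 3-dimensional embedding table and hypothetical 3x2 weight matrices. A real model learns all of these numbers and uses far larger dimensions.

```python
def matvec(vec, mat):
    """Row vector times matrix: vec has length d_in, mat is d_in x d_out."""
    return [sum(vec[i] * mat[i][j] for i in range(len(vec))) for j in range(len(mat[0]))]

# Hypothetical embedding table (learned in a real model)
embedding = {"I": [1.0, 0.0, 2.0], "went": [0.0, 1.0, 1.0]}

# Hypothetical projection matrices (learned in a real model)
W_Q = [[1, 0], [0, 1], [1, 1]]
W_K = [[0, 1], [1, 0], [1, 0]]
W_V = [[1, 1], [0, 1], [0, 0]]

E_I = embedding["I"]
Q_I = matvec(E_I, W_Q)   # E_I x W_Q
K_I = matvec(E_I, W_K)   # E_I x W_K
V_I = matvec(E_I, W_V)   # E_I x W_V
print(Q_I, K_I, V_I)     # [3.0, 2.0] [2.0, 1.0] [1.0, 1.0]
```

The same three matrices are shared by every token, so "went" would be projected with exactly the same W_Q, W_K, and W_V.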

And how are W_Q, W_K, and W_V created? Just like all other parameters in a Transformer, they are learned during the training process. They start as random, meaningless values and gradually become meaningful through training.

Computational Cost: O(N^2)

Self-Attention is all about calculating the relationship between every pair of words. The word "I" must be compared with "I," "went," "to,"..., "money." And "money" must be compared with "I," "went," "to,"..., "money."

If the input sentence has N words, the first word has to be compared with N words. Since this process must be repeated for all N words, the total computational complexity of Self-Attention is N x N, or O(N^2).

What does quadratic complexity mean in practice?

  • If your input has 10 words, you need 10 x 10 = 100 calculations.
  • If your input has 100 words, you need 100 x 100 = 10,000 calculations.

See the problem? The input length grew by 10x, but the computation required grew by 100x. As the input gets longer, the computation grows quadratically: doubling the length quadruples the work.
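You can check the arithmetic by literally counting the (query, key) pairs that full Self-Attention has to score. The helper name is just for illustration:

```python
def attention_pairs(n):
    """Number of (query, key) comparisons in full self-attention over n tokens."""
    return sum(1 for _query in range(n) for _key in range(n))

for n in (10, 100, 1000):
    print(n, attention_pairs(n))   # 100, 10000, 1000000 comparisons
```

Each 10x increase in length multiplies the count by 100, which is exactly the N x N growth described above.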

While Self-Attention is a powerful technique for understanding context, it is also a very expensive and time-consuming operation.


The Grind of Auto-Regressive Generation

Before we introduce the KV Cache, we need to talk about another key characteristic of LLMs: auto-regressive generation. The term might sound unfamiliar, but the concept is straightforward. Let's see how an LLM generates a sentence.

Suppose you give an LLM the prompt: "What is the tallest mountain?" It will likely respond, "It is Mount Everest."

Let's break down how it gets there, step by step:

  1. The LLM predicts the most likely next token after "What is the tallest mountain?" -> "It"
  2. The LLM appends this new token to the input. It now predicts the most likely token after "What is the tallest mountain? It" -> "is"
  3. The LLM appends "is." It now predicts the most likely token after "What is the tallest mountain? It is" -> "Mount"
  4. The LLM appends "Mount." It now predicts the most likely token after "What is the tallest mountain? It is Mount" -> "Everest"
  5. The LLM appends "Everest." It now predicts the most likely token after "What is the tallest mountain? It is Mount Everest" -> "<END>"

When the model generates a special <END> token, it stops. The final output is "It is Mount Everest."

Essentially, an LLM generates text by choosing a next token (in the simplest case, the most probable one), appending it to the sequence, and repeating the process with the new, longer sequence until it decides to stop.

  • Regressive means predicting a new value based on past values. Predicting today's weather based on yesterday's is a regressive system.
  • Auto means that the model's own output is fed back in as the next input. In our example, the generated token "It" became part of the input for the very next step.

Putting them together, this method of generating text one token at a time is called Auto-Regressive generation.
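The generation loop itself is simple. In this minimal sketch, a hypothetical lookup table stands in for the LLM: it maps the full text so far to the next token. A real model would run Self-Attention over the whole sequence at every step instead.

```python
# Hypothetical stand-in for an LLM: full text so far -> next token
TOY_NEXT_TOKEN = {
    "What is the tallest mountain?": "It",
    "What is the tallest mountain? It": "is",
    "What is the tallest mountain? It is": "Mount",
    "What is the tallest mountain? It is Mount": "Everest",
    "What is the tallest mountain? It is Mount Everest": "<END>",
}

def predict_next_token(tokens):
    return TOY_NEXT_TOKEN[" ".join(tokens)]

def generate(prompt_tokens, max_new_tokens=10):
    tokens = list(prompt_tokens)
    generated = []
    for _ in range(max_new_tokens):
        next_token = predict_next_token(tokens)   # look at the WHOLE sequence so far
        if next_token == "<END>":
            break
        tokens.append(next_token)                 # feed the output back in ("auto")
        generated.append(next_token)
    return " ".join(generated)

prompt = ["What", "is", "the", "tallest", "mountain?"]
print(generate(prompt))   # It is Mount Everest
```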

Predicting the Next Token

Let's zoom in on just the first step of that process:

  1. The LLM receives the input "What is the tallest mountain?".
  2. The LLM calculates Self-Attention for this entire input.
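  3. Based on the Self-Attention output for the final token, "mountain?", it predicts the next token. (Technically, this output still passes through a feed-forward layer and the remaining Transformer layers. The final hidden state is then projected onto a score for every token in the vocabulary, and the next token is chosen from those scores.)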
  4. The result is the token "It".

You might wonder if looking only at the output for "mountain?" is enough. It is, because thanks to Self-Attention, the representation for "mountain?" already contains information about its relationships with all the other words: "What," "is," "the," and "tallest." Remember, the Query for "mountain?" was compared against the Keys of every word in the sequence, and their Values were blended in accordingly.

And how are the Q, K, and V vectors for each word calculated? As mentioned before, by multiplying their embeddings by the pre-trained W_Q, W_K, and W_V matrices.

This means that in step 1, we calculated the Q, K, and V values for "What," "is," "the," "tallest," and "mountain?".

The stage is now set for our hero, the KV Cache.


Enter the KV Cache

Let's move to step 2. The input is now "What is the tallest mountain? It". The LLM needs to predict the next word, so it will calculate Self-Attention using "It" as the Query. To do this, it converts "It" into a Query vector, compares it with the Keys of every token in the sequence ("What," "is," "the," "tallest," "mountain?," and "It" itself), and combines their Values according to those scores.

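But wait! Look at the K and V for "What": we already calculated them in the previous step, when we generated "It." The same is true for "is," "the," "tallest," and "mountain?". These values haven't changed. In an LLM, each token can only attend to the tokens before it (this is called causal masking), so appending "It" to the end doesn't alter anything computed for the earlier tokens. There's no need to re-calculate the exact same values.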

So, instead of re-calculating the K and V values for the old tokens, we can just reuse the ones we computed in the previous step.

All we have to do is save (or cache) the K and V values from the previous step instead of throwing them away. This saves us from performing the same matrix multiplications (E x W_K and E x W_V) over and over again, which dramatically speeds up the process.
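To see how much work is saved, count how many tokens get their K and V projected during generation, per layer. This is a minimal sketch that assumes step 0 processes the whole prompt and each later step adds exactly one token. The function name is hypothetical.

```python
def kv_projection_count(prompt_len, steps, use_cache):
    """Count token K/V projections (E x W_K and E x W_V) over `steps` generation steps."""
    total = 0
    for step in range(steps):
        seq_len = prompt_len + step            # input length at this step
        if use_cache and step > 0:
            total += 1                         # only the newest token is new
        else:
            total += seq_len                   # no cache (or first step): every token
    return total

# "What is the tallest mountain?" (5 tokens), then It / is / Mount / Everest / <END>
print(kv_projection_count(5, 5, use_cache=False))  # 35
print(kv_projection_count(5, 5, use_cache=True))   # 9
```

Without the cache the total grows quadratically with the number of generated tokens. With the cache it grows linearly.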

This technique of saving and reusing the Key and Value vectors from previous steps is called the KV Cache.

There are two key things to keep in mind about the KV Cache:

  1. You still need to calculate the K and V for the newly generated token. In our example, the K and V for "It" must be computed because it wasn't part of the input in the previous step. This new K and V pair is then added to the cache for the next step.

  2. You might wonder why we don't cache the Q vector. When generating the next token, we only care about the Self-Attention for the very last token in the sequence. We always need the Q of the current last token, and this Q, just like its K and V, has never been calculated before. Caching old Q vectors would be useless because in the next auto-regressive step, those tokens are now in the past, and we don't need to query from their position again. In short: to calculate attention for the current token, you need its new Q and all past K and V values.
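Both points can be sketched for a single attention layer with a single head, using hypothetical weights and toy embeddings. Each step computes Q, K, and V only for the new token, appends K and V to the cache, and attends from the new Q over everything cached. The `KVCache` class is illustrative, not any library's API. The loop checks that the result matches recomputing K and V for the whole sequence from scratch.

```python
import math

def matvec(vec, mat):
    return [sum(vec[i] * mat[i][j] for i in range(len(vec))) for j in range(len(mat[0]))]

def attend(q, keys, values):
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[j] for e, v in zip(exps, values)) for j in range(len(values[0]))]

# Hypothetical 2x2 projection matrices (one layer, one head)
W_Q = [[1.0, 0.5], [0.0, 1.0]]
W_K = [[0.5, 0.0], [1.0, 1.0]]
W_V = [[1.0, 0.0], [0.5, 2.0]]

class KVCache:
    """Stores K and V of every token seen so far; Q is never stored."""
    def __init__(self):
        self.keys, self.values = [], []

    def step(self, embedding):
        q = matvec(embedding, W_Q)                     # new token's Q (used once)
        self.keys.append(matvec(embedding, W_K))       # new token's K goes into the cache
        self.values.append(matvec(embedding, W_V))     # new token's V goes into the cache
        return attend(q, self.keys, self.values)       # attend over all cached K and V

def no_cache_step(embeddings):
    # Recompute K and V for every token, then attend from the last token's Q
    keys = [matvec(e, W_K) for e in embeddings]
    values = [matvec(e, W_V) for e in embeddings]
    return attend(matvec(embeddings[-1], W_Q), keys, values)

sequence = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, 1.0]]
cache = KVCache()
for t, emb in enumerate(sequence, start=1):
    cached = cache.step(emb)
    recomputed = no_cache_step(sequence[:t])
    assert all(abs(a - b) < 1e-9 for a, b in zip(cached, recomputed))
print("cache size:", len(cache.keys))  # 4: one K (and one V) per token
```

A real model keeps a cache like this for every layer and every head. In practice, the prompt's K and V are usually computed in one batch rather than one token at a time.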


The Downsides of KV Cache

By eliminating redundant and complex calculations, the KV Cache provides a massive speed boost for LLM inference. But as they say, there's no such thing as a free lunch.

The KV Cache works by storing previously computed values in memory instead of discarding them. This means it consumes memory space.

In a real Transformer, Self-Attention is implemented as Multi-Head Self-Attention (you can think of it as running multiple Self-Attention calculations in parallel), and each layer of the Transformer has its own distinct attention values. Therefore, the memory required for the KV Cache is roughly:

(Number of tokens) x (Number of attention heads) x (Number of layers) x (Dimension of K and V vectors) x ...

As a result, the size of the KV Cache grows rapidly with longer input sequences and larger models (which have more heads and layers). Since the KV Cache is typically stored in the GPU's VRAM, this necessitates GPUs with very large amounts of VRAM.
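To put numbers on this, one common way to write the estimate explicitly multiplies by 2 (one K and one V per token), by the bytes per number, and by the batch size. This sketch assumes 16-bit values (2 bytes each) and a hypothetical model configuration chosen only to illustrate the arithmetic:

```python
def kv_cache_bytes(num_tokens, num_layers, num_kv_heads, head_dim,
                   bytes_per_value=2, batch_size=1):
    """Rough KV cache size: K and V (hence 2) for every token, head and layer."""
    return 2 * num_tokens * num_layers * num_kv_heads * head_dim * bytes_per_value * batch_size

# Hypothetical configuration: 32 layers, 32 heads of dimension 128, 4096 tokens
size = kv_cache_bytes(num_tokens=4096, num_layers=32, num_kv_heads=32, head_dim=128)
print(f"{size / 2**30:.1f} GiB")   # 2.0 GiB for a single sequence
```

The size is linear in every factor, so a context 32 times longer (about 131k tokens) would need 64 GiB for this one sequence.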

This memory consumption is a serious challenge, especially for modern models that support context windows of hundreds of thousands or even millions of tokens. To mitigate this, techniques like Multi-Query Attention (all attention heads share a single set of K and V) and Grouped-Query Attention (heads are split into groups, and each group shares one set of K and V) are often used.
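The difference between these schemes comes down to how many K/V heads get cached and which of them each query head reads. This is a minimal sketch that assumes consecutive query heads share a K/V head. That layout is common, but it is an assumption here, and the helper name is hypothetical.

```python
def kv_head_for(query_head, num_query_heads, num_kv_heads):
    """Which shared K/V head a query head reads from under grouped-query attention."""
    assert num_query_heads % num_kv_heads == 0, "query heads must split evenly into groups"
    group_size = num_query_heads // num_kv_heads
    return query_head // group_size

num_query_heads = 8
for num_kv_heads, name in [(8, "multi-head"), (2, "grouped-query"), (1, "multi-query")]:
    mapping = [kv_head_for(h, num_query_heads, num_kv_heads) for h in range(num_query_heads)]
    print(f"{name:14s} K/V heads cached: {num_kv_heads}  query->kv: {mapping}")
```

With 8 query heads, grouped-query attention with 2 K/V heads caches a quarter of the K and V that plain multi-head attention does, and multi-query attention caches an eighth.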


Wrapping Up

Today, we explored the KV Cache, a cornerstone optimization that has dramatically improved Transformer inference performance.

Here are the key takeaways:

  • Self-Attention uses Q, K, and V vectors to understand context, but this is computationally expensive (O(N^2)).
  • LLMs generate text using an auto-regressive method, where each new word is added to the input for the next step.
  • Re-calculating K and V for all previous tokens at every step is extremely inefficient.
  • The KV Cache solves this by storing and reusing K and V values, eliminating redundant computations.
  • The result is a huge boost in inference speed, but it comes at the cost of increased memory usage, especially for long sequences and large models.

Thanks for reading this deep dive!
