In this “Hands-On Transformer Deep Dive” series, we go step-by-step through the algorithms and components of modern Transformers, with working code and engineering insights. Follow along to deepen your understanding — and build your own Transformers from scratch.
In this article, we dive deep into the core attention mechanism used in most of today’s Transformer models: masked scaled dot-product attention. We’ll implement it from scratch using only PyTorch, and look into when and where to apply the scaling, masking, and dropout, and why.
Introduction
Transformers have become the foundation of modern generative language models, and the attention mechanism lies at their core. There are many flavors of attention, e.g., additive attention (Bahdanau et al., 2014), dot-product attention (Luong et al., 2015), scaled dot-product attention (Vaswani et al., 2017), masked attention (used for padding and causal decoding), and multi-head attention (which itself has multiple variants). Masked scaled dot-product attention is the foundational building block of the autoregressive GPT-like models prevalent today.
Implementation
The attention mechanism enables LLMs to learn and generate context-dependent representations by letting each token “attend” to the other tokens. Its basic working is captured by a single formula: given queries Q, keys K, and values V, the attention output is calculated as follows (d_k is the dimension of the keys):

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
The implementation, however, includes a couple more details:
- mask: to enable padding and causal attention (where a token can only “attend” to tokens that came before it)
- dropout: a regularization method to prevent the model from relying too heavily on a few specific positions in the sequence
Below is the code implementing the masked scaled dot-product attention mechanism step-by-step:
###
# Masked Scaled Dot-product Attention Implementation
###
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def attention(query, key, value, mask=None, dropout=0.1):
    # Step 1 & 2: dot product and scale
    d_k = key.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # Step 3: mask
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Step 4: softmax
    attn_weights = F.softmax(scores, dim=-1)
    # Step 5: dropout
    attn_weights = nn.Dropout(dropout)(attn_weights)
    # Step 6: weighted sum
    output = torch.matmul(attn_weights, value)
    return output
For easy demonstration, this is implemented as a plain function. There is also a PyTorch module implementation at the end of the article, which you can plug into your own PyTorch network.
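To see it in action, here is a quick sanity check with random tensors and a causal mask (the shapes and values are purely illustrative):

import torch

batch_size, seq_len, d_model = 2, 5, 16
query = torch.randn(batch_size, seq_len, d_model)
key = torch.randn(batch_size, seq_len, d_model)
value = torch.randn(batch_size, seq_len, d_model)

# Causal mask: 1 = allowed to attend, 0 = blocked (future positions).
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

output = attention(query, key, value, mask=causal_mask, dropout=0.1)
print(output.shape)  # torch.Size([2, 5, 16])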
Step-by-step Explanation and Nuances
Following the 6 steps, we can see the formula actually implemented is:

output = Dropout_p(softmax(Mask_M(QK^T / sqrt(d_k)))) V

where:
- Q: query tensor of shape (batch_size, seq_len, d_q)
- K: key tensor of shape (batch_size, seq_len, d_k)
- V: value tensor of shape (batch_size, seq_len, d_v); normally d_q, d_k, and d_v are the same
- M: mask matrix of shape (seq_len, seq_len), with 0 for masking (those score positions are filled with -inf) and 1 for passing through
- p: dropout probability between 0 and 1, where each element has probability p of being set to 0 and probability 1-p of being kept and scaled by 1/(1-p) (to compensate for the removed elements and keep the expected sum)
Let’s look at each step and understand the nuances of why they have to be in this order.
Step 1. Dot Product
scores = torch.matmul(query, key.transpose(-2, -1))
Here we use the dot product to compute the similarity between each query and each key, giving the raw attention scores. All further operations build on these fundamental scores.
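As a concrete (and purely illustrative) shape check, queries and keys of shape (batch_size, seq_len, d_k) produce one raw score per (query, key) pair:

import torch

q = torch.randn(2, 5, 16)  # (batch_size, seq_len, d_k)
k = torch.randn(2, 5, 16)  # (batch_size, seq_len, d_k)
scores = torch.matmul(q, k.transpose(-2, -1))
print(scores.shape)  # torch.Size([2, 5, 5])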
Step 2. Scale by sqrt(d_k)
d_k = key.size(-1)
scores = scores / math.sqrt(d_k)
Here we scale the raw attention scores to prevent the next step, softmax, from becoming too “peaky”. This stabilizes gradients for models with a large key dimension d_k: as d_k grows, the variance of the dot products also grows (roughly proportionally to d_k).
Why do this scaling before softmax? The softmax function is highly sensitive to large differences between its inputs: higher variance makes the largest score dominate (i.e., its probability becomes close to 1 and the others close to 0). This “peaky” distribution leads to vanishing gradients and poor gradient flow, which hurt learning. Therefore, we scale by sqrt(d_k) here, before moving on to softmax.
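A minimal experiment (assuming unit-variance random queries and keys) shows how the dot-product variance grows with d_k, and how dividing by sqrt(d_k) brings it back to around 1:

import math
import torch

for d_k in (16, 64, 256):
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    dots = (q * k).sum(dim=-1)  # 10,000 independent dot products
    print(d_k, dots.var().item(), (dots / math.sqrt(d_k)).var().item())
# The unscaled variance is roughly d_k; the scaled variance stays around 1.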
Step 3. Mask
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
The mask is used for padding and/or causal attention (where a token can only “see” and attend to the tokens before it). In the code above, blocked positions are filled with a large negative number (float('-inf')); an equivalent and common alternative is an additive mask that adds a large negative number to those positions. Either way, after softmax the blocked positions end up with (near-)zero probability.
Why apply the mask before softmax? Because we use softmax to calculate the probability distribution from attention scores. For the tokens the query shouldn’t “see” we need their probabilities to be 0, and for the rest we need their probabilities to sum to 1. Masking before softmax with -inf satisfies this. If we mask after softmax (e.g., with zero), the masked tokens have already contributed to the probability distribution, and it also causes the probabilities to no longer sum to 1.
On the other hand, what about masking before computing the attention scores, i.e., before the dot product in Step 1? It seems intuitive to not let the query “see” the keys of tokens it shouldn’t see in the first place, right? Unfortunately, because of how the dot product works, zeroing out those key vectors doesn’t remove their influence: they still produce a score of 0, which softmax turns into a non-zero probability. And masking the keys with -inf breaks the computation itself, since the dot product then produces -inf or NaN values.
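For reference, here is one common way to build the two kinds of masks mentioned above; this is just a sketch that follows the (seq_len, seq_len), 0-means-blocked convention used in this article:

import torch

seq_len = 4

# Causal mask: position i may only attend to positions <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])

# Padding mask: block keys that are padding tokens (here the last position is padding).
token_is_real = torch.tensor([1., 1., 1., 0.])
padding_mask = token_is_real.unsqueeze(0).expand(seq_len, seq_len)

# The two can be combined with an element-wise product.
combined_mask = causal_mask * padding_mask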
Step 4. Softmax
attn_weights = F.softmax(scores, dim=-1)
Apply softmax to the attention scores to compute the attention weights, which sum to 1. The attention weights tell us how much attention the query should pay to each key. After the dot product, the attention scores are a tensor of shape (batch_size, (num_heads,) query_len, key_len), where query_len and key_len are the same in self-attention and are often noted as seq_len. We want a probability distribution over the keys, which live in the last dimension, so dim=-1 tells softmax which dimension to normalize over.
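A small check with made-up scores illustrates both points: softmax over dim=-1 makes each row sum to 1, and positions masked with -inf end up with zero weight:

import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, float('-inf')],
                       [0.5, float('-inf'), float('-inf')]])
weights = F.softmax(scores, dim=-1)
print(weights)
# tensor([[0.7311, 0.2689, 0.0000],
#         [1.0000, 0.0000, 0.0000]])
print(weights.sum(dim=-1))  # each row sums to 1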
Step 5. Dropout
attn_weights = nn.Dropout(dropout)(attn_weights)
Here we apply dropout regularization for smoother gradients and better generalization. nn.Dropout(dropout) gives us a dropout module with the rate we need (it randomly zeroes each element with probability p and scales the rest by 1/(1-p)). We then pass attn_weights through this dropout module to apply it.
Why is dropout applied after softmax? As explained in the Step 3 section, softmax turns the attention scores into the attention weights: a probability distribution representing how much “attention” each key should get. If we applied dropout before softmax, setting some attention scores to 0 and scaling up the others, we’d distort that distribution before it’s even computed; not to mention that the elements we “dropped out” (set to zero) wouldn’t necessarily end up with zero probability, given how the softmax function works (exp(0) = 1, not 0).
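To see what the dropout module actually does to the weights, here is a tiny (randomized, so your zeros will land elsewhere) example: survivors are scaled by 1/(1-p), so the expected sum stays at 1.

import torch
import torch.nn as nn

weights = torch.full((1, 8), 0.125)    # a uniform attention distribution over 8 keys
dropped = nn.Dropout(p=0.25)(weights)  # a fresh module is in training mode by default
print(dropped)
# Roughly 25% of the entries become 0; the survivors become 0.125 / 0.75 ≈ 0.1667.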
Step 6. Compute Output Value
output = torch.matmul(attn_weights, value)
Finally, multiply the attention weights with the value tensor and we get our masked scaled dot-product attention output, hooray!
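As a tiny worked example (numbers made up), each output row is just a weighted average of the value rows, with the attention weights as the coefficients:

import torch

attn_weights = torch.tensor([[0.7, 0.3]])  # one query attending to two keys
value = torch.tensor([[1.0, 0.0],          # value vector of key 0
                      [0.0, 1.0]])         # value vector of key 1
output = torch.matmul(attn_weights, value)
print(output)  # tensor([[0.7000, 0.3000]]) = 0.7 * value[0] + 0.3 * value[1]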
PyTorch Module Implementation
Here is a PyTorch module implementation that you can plug into your own PyTorch networks.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(nn.Module):
    def __init__(self, mask=None, dropout=0.1):
        super().__init__()
        self.mask = mask
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        # Step 1 & 2: dot product and scale
        d_k = key.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        # Step 3: mask
        if self.mask is not None:
            scores = scores.masked_fill(self.mask == 0, float('-inf'))
        # Step 4: softmax
        attn_weights = F.softmax(scores, dim=-1)
        # Step 5: dropout
        attn_weights = self.dropout(attn_weights)
        # Step 6: weighted sum
        output = torch.matmul(attn_weights, value)
        return output
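And a quick, illustrative example of using this module (remember to call .eval() at inference time so dropout is turned off):

import torch

seq_len, d_model = 5, 16
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
attn = Attention(mask=causal_mask, dropout=0.1)

q = torch.randn(2, seq_len, d_model)
k = torch.randn(2, seq_len, d_model)
v = torch.randn(2, seq_len, d_model)
out = attn(q, k, v)
print(out.shape)  # torch.Size([2, 5, 16])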
This is the first in a series of articles diving deep into the Transformer model architecture and its algorithm implementations. Next up, we’ll look into multi-head attention and its variants. Stay tuned, and tell me what you think and what you’d like to read!
References & Further Reading
- Vaswani et al. (2017), Attention Is All You Need
- Harvard NLP, The Annotated Transformer