Been trying to understand the scaling in the attention formula, specifically the division by sqrt(d_k). At first it confused me: why do we need to divide at all?
My confusion was this: inside softmax we already subtract the max value from every score (so exp doesn't blow up our numbers), so why do we also need to scale before that step?
It turns out the two steps do different jobs: subtraction is about numerical stability, while division is about statistical calibration.
Division vs. Subtraction
When we divide by sqrt(d_k), we shrink every value proportionally, which also shrinks the differences between them. For example, dividing [100, 102, 103] by 10 (sqrt(d_k) for d_k = 100) gives [10.0, 10.2, 10.3]: the 2-unit and 1-unit gaps become 0.2 and 0.1. This brings the values closer together before they reach softmax.
In contrast, when we subtract a constant (like the max inside softmax), we shift where the values sit on the number line without changing the differences between them at all: [100, 102, 103] becomes [-3, -1, 0], but the gaps remain 2 and 1.
At first, I thought: if we just want to tame the magnitudes for softmax, why not simply subtract the max value, as softmax already does for stability? But then it occurred to me that subtraction doesn't actually bring the values closer together; it only shifts them.
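A minimal sketch of the two operations in plain Python, assuming d_k = 100 so that sqrt(d_k) = 10 (the divisor that turns [100, 102, 103] into [10.0, 10.2, 10.3]). The `gaps` helper is just for illustration:

```python
scores = [100.0, 102.0, 103.0]

def gaps(xs):
    """Distances between neighbouring values (rounded to hide float noise)."""
    return [round(b - a, 6) for a, b in zip(xs, xs[1:])]

divided = [x / 10.0 for x in scores]         # divide by sqrt(d_k), assuming d_k = 100
shifted = [x - max(scores) for x in scores]  # subtract the max, as inside softmax

print(divided, gaps(divided))  # [10.0, 10.2, 10.3] [0.2, 0.1]
print(shifted, gaps(shifted))  # [-3.0, -1.0, 0.0] [2.0, 1.0]
```

Division shrinks the gaps by the same factor as the values; subtraction moves the values but leaves the gaps untouched.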
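One way to see that subtraction can't do the calibrating: softmax gives exactly the same output for any constant shift of its inputs, so subtracting the max changes nothing about how peaked the distribution is. It only keeps exp from overflowing. A small sketch, where `naive_softmax` and `stable_softmax` are illustrative names and the shift by 900 is an arbitrary choice that makes Python's `math.exp` overflow:

```python
import math

def naive_softmax(xs):
    """Softmax without the max shift; math.exp raises OverflowError for large inputs."""
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(xs):
    """Same function, but every input is shifted down by the max first."""
    m = max(xs)
    return naive_softmax([x - m for x in xs])

small = [100.0, 102.0, 103.0]
large = [x + 900.0 for x in small]  # same gaps, just shifted up: [1000, 1002, 1003]

print([round(p, 2) for p in naive_softmax(small)])   # [0.04, 0.26, 0.71]
print([round(p, 2) for p in stable_softmax(small)])  # [0.04, 0.26, 0.71]
print([round(p, 2) for p in stable_softmax(large)])  # [0.04, 0.26, 0.71]

try:
    naive_softmax(large)
except OverflowError:
    print("naive softmax overflows on the shifted scores")
```

The distribution is identical in every case; the shift only decides whether the computation survives.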
Preserving Proportions for Softmax
The problem is that I need to preserve the relationships between the scores: their ordering and their proportions (the ratio 102/100 = 1.02 stays the same after division, and the 2:1 ratio between the gaps survives as 0.2:0.1). Softmax relies on these relative differences to produce meaningful probabilities, and I don't want to lose how much bigger one value is compared to another.
However, I also can't let the absolute size of these differences stay too large, because softmax's exponential exaggerates them further: the 3-unit spread in [100, 102, 103] turns into a distribution of roughly [0.04, 0.26, 0.71], where one value dominates.
So division is the perfect solution: it keeps the proportions the same (10.2/10.0 = 1.02, just like before) while shrinking the absolute differences (from 2 units to 0.2 units). The values no longer look far apart, so when softmax amplifies them the result is a more balanced distribution of roughly [0.28, 0.34, 0.38].
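Checking those two distributions with a plain-Python softmax (again assuming d_k = 100, so sqrt(d_k) = 10):

```python
import math

def softmax(xs):
    m = max(xs)  # max-subtraction for numerical stability only
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

scores = [100.0, 102.0, 103.0]
scaled = [x / 10.0 for x in scores]  # assumed d_k = 100, so sqrt(d_k) = 10

print([round(p, 2) for p in softmax(scores)])  # [0.04, 0.26, 0.71]
print([round(p, 2) for p in softmax(scaled)])  # [0.28, 0.34, 0.38]
```

The ordering of the probabilities is the same in both cases; only how sharply softmax favours the largest score changes.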
Another major reason is future-safe design. As d_k increases (say from 64 to 512), the raw dot product grows larger simply because we're summing more terms (its typical magnitude grows like sqrt(d_k)). This growth doesn't represent meaningful differences in attention; it's just an artifact of dimensionality. Dividing by sqrt(d_k) compensates for it and keeps the scale consistent, so whether d_k is small or large, the scores entering softmax have roughly the same spread.
Dividing by d_k directly would shrink the values too aggressively. Assuming the components of q and k are independent with mean 0 and variance 1, the variance of the dot product grows linearly with d_k, so its standard deviation grows like sqrt(d_k). Dividing by sqrt(d_k) keeps the standard deviation roughly constant (about 1), while dividing by d_k would make it shrink like 1/sqrt(d_k) and push softmax toward a nearly uniform distribution. Using sqrt(d_k) keeps the scale of the values entering softmax consistent, no matter the dimensionality.
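A rough simulation of that growth, assuming q and k have independent standard-normal entries (an illustrative assumption; learned queries and keys need not look like this). `score_std` is a hypothetical helper:

```python
import math
import random

def dot(q, k):
    return sum(a * b for a, b in zip(q, k))

def score_std(d_k, trials=1000, seed=0):
    """Empirical std of dot(q, k) for random q, k with i.i.d. N(0, 1) entries."""
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        samples.append(dot(q, k))
    mean = sum(samples) / trials
    return math.sqrt(sum((s - mean) ** 2 for s in samples) / trials)

for d_k in (64, 512):
    raw = score_std(d_k)
    print(f"d_k={d_k:4d}  raw std ~ {raw:6.2f}  "
          f"sqrt(d_k) = {math.sqrt(d_k):5.2f}  "
          f"std after / sqrt(d_k) ~ {raw / math.sqrt(d_k):4.2f}")
```

The raw spread should track sqrt(d_k) (about 8 and about 22.6), while the scaled spread should stay close to 1 for both sizes.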
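The variance argument written out, under the same assumption (all components of q and k independent, mean 0, variance 1):

```latex
q \cdot k = \sum_{i=1}^{d_k} q_i k_i, \qquad
\mathbb{E}[q_i k_i] = \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0, \qquad
\operatorname{Var}(q_i k_i) = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = 1

\operatorname{Var}(q \cdot k) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = d_k

\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1,
\qquad
\operatorname{Var}\!\left(\frac{q \cdot k}{d_k}\right) = \frac{d_k}{d_k^2} = \frac{1}{d_k}
```

With d_k = 512, dividing by d_k would leave a standard deviation of about 0.044, so the scores would be nearly identical and softmax would return almost uniform weights.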
Key Insight
The key insight is that values need to be close together before entering softmax, so that softmax can do the exaggeration through its exponential in a controlled way. Subtraction, meanwhile, only ensures numerical stability; it doesn't affect the relative distances that softmax actually cares about.
Basically, there are 3 important things happening here:
- Division changes how far apart the values are (brings them closer)
- Subtraction changes where they sit on the number line (doesn't change the separation)
- We want the values close together BEFORE softmax, so that softmax's exponential amplification produces a reasonable distribution rather than an extreme one
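Putting the three pieces together, here is a minimal sketch of scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V, in plain Python on lists of row vectors. It's meant to show where the division and the subtraction each happen, not to be an efficient implementation; `attention` and the toy Q, K, V values are made up for illustration:

```python
import math

def softmax(xs):
    # Subtraction: shift by the max so exp never overflows (output unchanged)
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d_k)) V for lists of row vectors."""
    d_k = len(K[0])
    scale = math.sqrt(d_k)  # division: calibrates the spread of the scores
    outputs = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in K]
        weights = softmax(scores)  # the exponential does the exaggeration
        # Weighted sum of the value vectors, one output column at a time
        outputs.append([sum(w * v[j] for w, v in zip(weights, V))
                        for j in range(len(V[0]))])
    return outputs

# Toy example with d_k = 4 (so sqrt(d_k) = 2)
Q = [[1.0, 0.0, 1.0, 0.0]]
K = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0],
     [1.0, 1.0, 1.0, 1.0]]
V = [[1.0, 0.0],
     [0.0, 1.0],
     [0.5, 0.5]]

result = attention(Q, K, V)
print([[round(x, 3) for x in row] for row in result])  # [[0.633, 0.367]]
```

Here the raw scores [2, 0, 2] become [1, 0, 1] after dividing by sqrt(4) = 2, and softmax then splits most of the weight between the first and third keys.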
Let the softmax do the exaggeration.