Tri Vo

Understanding the Design of Optimizers with me

Ok, it's midnight during Halloween, and I'm talking about Optimizers. Such a thrill lol. And the goal today is to teach you how AdamW is calculated mathematically, and what the intention is behind its design.

So, what is an optimizer in the context of LLMs? First of all, the name is rather deceptive to me. An optimizer is less of an object and more of a verb / methodology / concept.

An optimizer is a WAY to update the parameters in our LLMs during training. Specifically, it's how we update each weight value during backprop, for example in our attention matrices or in the linear layers of our MLP blocks.

In the simplest form, with "chain rule", we have this classic optimizer called Stochastic Gradient Descent (SGD):

$$\theta_t = \theta_{t-1} - \eta \nabla_{\theta} \mathcal{L}(\theta_{t-1})$$

where $\eta$ is the learning rate and $\nabla_{\theta} \mathcal{L}(\theta_{t-1})$ is the gradient of our loss function with respect to $\theta$, evaluated at the current parameters.

One question I had here is "why stochastic?". Nothing in the above formula produces a different result each time we evaluate it. True. The answer doesn't lie in the formula, but is cleverly hidden in the LLM's training implementation. Because training an LLM involves tons of inputs, computing the gradient over all of them (batch gradient descent) would be terribly expensive. Therefore, we only sample a few inputs at each step to calculate our loss, so-called mini-batch gradient descent. This sampling is what introduces stochasticity into our calculation.
(FYI: the full-batch update is also sometimes called Deterministic Gradient Descent (DGD), but that name is less popular imo.)
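To make the sampling concrete, here is a minimal sketch of one mini-batch SGD step in NumPy. The linear-regression loss and all of the names here are my own illustrative choices, not from any particular library:

```python
# Minimal mini-batch SGD sketch (hypothetical mean-squared-error loss for y ≈ X @ theta).
import numpy as np

def sgd_step(theta, X, y, lr=1e-2, batch_size=32):
    # Sampling a random mini-batch is exactly where the "stochastic" comes from.
    idx = np.random.choice(len(X), size=batch_size, replace=False)
    X_b, y_b = X[idx], y[idx]

    # Gradient of the MSE loss on this mini-batch with respect to theta.
    grad = (2.0 / batch_size) * X_b.T @ (X_b @ theta - y_b)

    # theta_t = theta_{t-1} - eta * grad
    return theta - lr * grad
```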

Remember, our goal is to understand AdamW, and this design is very simple and far from AdamW's. So the question is: how do we improve the updates? Is always greedily (and blindly) following the steepest direction at our current location really our best option? Here is where the fun part begins.

*(Image: example of a ravine)*

Looking at the above image from a global perspective, it would make more sense for the blue point to head straight toward the local minimum, but it actually oscillates a lot before converging. You can think of this like building a road downhill: it's easier for a car to take a longer route of zigzags and gentle curves than to drive straight down the hill, even though the straight path is obviously faster. In the context of a loss landscape, the locally optimal step doesn't guarantee the best direction toward the lower contours, as the contours can have small spikes, and the steepest direction is often not one of those "easier options". Therefore, new "routing methods" were introduced to tackle this shortsighted, local behavior.

Momentum

What if we kept a history of previous steps and chose the next direction based on the tendency of those past decisions? In physics, this concept is called momentum. This "internal force" will kick you past your intended next location based on your past movements - just like the "momentum" you feel when you jump off a moving bus. With this extra movement, you can EXPLORE a better place on your contour map - it's like escaping from your destined death! or, at worst, just taking a longer route. Here, we record the effect of past gradients in $v_t$ as follows:

$$v_t = \beta v_{t-1} + (1 - \beta)\,\nabla_{\theta} \mathcal{L}(\theta)$$

$$\theta_t = \theta_{t-1} - \eta\, v_t$$
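In code, the only change from vanilla SGD is that we step along a running average of gradients instead of the raw gradient. A minimal sketch (the names are mine; `grad` is assumed to be the current mini-batch gradient):

```python
# Minimal momentum sketch: v is the exponential moving average of past gradients.
def momentum_step(theta, grad, v, lr=1e-2, beta=0.9):
    v = beta * v + (1 - beta) * grad   # blend the past direction with the new gradient
    theta = theta - lr * v             # step along the accumulated "velocity"
    return theta, v
```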

It's great that we've added some spice to our previous vanilla SGD optimizer. What other spices can we use for this vanilla recipe? With the goal of EXPLORING other directions, we can weight each direction based on how frequently (and how strongly) it has been updated.

$$v_t = v_{t-1} + (\nabla_{\theta} \mathcal{L}(\theta))^2$$

$$\theta_t = \theta_{t-1} - \frac{\eta}{\sqrt{v_t} + \epsilon}\, \nabla_{\theta} \mathcal{L}(\theta)$$

This optimizer is called AdaGrad. It accumulates the squared gradients over time and divides the current gradient by the square root of that running sum (an L2-style normalization term). In simple words, directions whose gradients have been large get penalized harder, and vice versa. And $\epsilon$ is a very small number, around $10^{-8}$, just to make sure we do not divide by 0.

This strategy is helpful for sparse data, where some specific signals rarely show up, as we can give them more presence. Likewise, it reduces the impact of the dominant directions. In the end, this is very helpful for smoothing the gradient signal.
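A minimal AdaGrad sketch under the same assumptions as before (per-parameter arrays, `grad` is the current gradient; again, the names are my own):

```python
import numpy as np

# Minimal AdaGrad sketch: v accumulates squared gradients and never decays.
def adagrad_step(theta, grad, v, lr=1e-2, eps=1e-8):
    v = v + grad ** 2                               # lifetime sum of squared gradients
    theta = theta - lr / (np.sqrt(v) + eps) * grad  # frequent/large directions get damped
    return theta, v
```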

However, a plain summation is a simple yet naive approach for this normalization. As we accumulate the squared gradients linearly over time, especially across very long training runs, the accumulator keeps piling up. If we don't introduce any "restriction" on it, it would eventually shrink our effective updates to zero. Root Mean Square Propagation (RMSProp) was invented for this problem. It is the same as AdaGrad, but it introduces some "weighting" into the accumulation.

$$v_t = \beta v_{t-1} + (1 - \beta)(\nabla_{\theta} \mathcal{L}(\theta))^2$$

$$\theta_t = \theta_{t-1} - \frac{\eta}{\sqrt{v_t} + \epsilon}\, \nabla_{\theta} \mathcal{L}(\theta)$$

where $\beta$ is between 0 and 1. In this accumulation, the older gradients have less and less impact on the current normalization term because:
$$
\begin{aligned}
v_t &= (1 - \beta)(\nabla_{\theta_t} \mathcal{L}(\theta_t))^2 + \beta v_{t-1} \\
&= (1 - \beta)(\nabla_{\theta_t} \mathcal{L}(\theta_t))^2 + \beta \left[ (1 - \beta)(\nabla_{\theta_{t-1}} \mathcal{L}(\theta_{t-1}))^2 + \beta v_{t-2} \right] \\
&= (1 - \beta)(\nabla_{\theta_t} \mathcal{L}(\theta_t))^2 + \beta (1 - \beta)(\nabla_{\theta_{t-1}} \mathcal{L}(\theta_{t-1}))^2 + \beta^2 v_{t-2}
\end{aligned}
$$

After each step, every older squared gradient gets multiplied by $\beta$ one more time, exponentially shrinking its contribution since $\beta < 1$. Tuning this hyperparameter helps keep our update sizes from either vanishing or exploding!
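The code-level difference from AdaGrad is literally one line: the accumulator becomes an exponential moving average instead of a plain sum (again, a sketch with my own names):

```python
import numpy as np

# Minimal RMSProp sketch: same as AdaGrad, but old squared gradients decay by beta.
def rmsprop_step(theta, grad, v, lr=1e-3, beta=0.9, eps=1e-8):
    v = beta * v + (1 - beta) * grad ** 2           # bounded, exponentially weighted sum
    theta = theta - lr / (np.sqrt(v) + eps) * grad
    return theta, v
```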

Ok, we have explored 2 approaches to adaptively steer the wheel - Momentum and gradient smoothing (RMSProp). What if we combine both ideas? This is where we get the Adam optimizer. In other words, we replace the plain $\nabla_{\theta} \mathcal{L}(\theta)$ in RMSProp's update with the Momentum-style moving average.

*(Image: the Adam update rules)*

where $\hat{v}_t$ comes from RMSProp and $\hat{m}_t$ comes from Momentum.
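In case the image doesn't render, here is my reading of those update lines, following the original Adam paper (with $g_t$ denoting the mini-batch gradient at step $t$):

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &= \frac{v_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon}\, \hat{m}_t
\end{aligned}
$$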

Ok, finally we have reached the focal point of this discussion.

Now, the first two lines make sense to us since they are straight out of the previous approaches. However, where do the next two lines come from?

First, let's expand the $m_t$ formula to analyze how its value changes over time.
*(Image: Stage 1)*
We reach a general formula if we keep expanding. And since $m_0$ is usually initialized to 0, we can shorten the final expression.
*(Image: Stage 2 - general formula)*
Then take the expectation of both sides. Since the expectation of a sum equals the sum of the expectations, we can mathematically deduce the following equation.
*(Image: Stage 3 - expectation)*
The gradient distribution here could be either stationary or non-stationary, which tells us how drastically the values drift away from the mean over time. And since the summation follows a geometric series, we have our final formula:
*(Image: Stage 4 - geometric series)*
If the distribution is stationary, the expected value of the gradient is approximately constant, termed $\mu$, across all training steps, and $\zeta$ is 0.
If it is non-stationary, $\zeta$ is not 0, accounting for the difference from the stationary case. However, $\zeta$ can still remain small because the decay rate $\beta_1$ is chosen such that the older gradients' contributions are exponentially downweighted into insignificance. In the original Adam paper, $\beta_1$ is set to a fairly large value (0.9); hence each incoming gradient is scaled by $(1 - 0.9) = 0.1$ when it enters the average and keeps shrinking by a factor of 0.9 at every later step. Over time, the effect of non-stationarity, i.e. the value of $\zeta$, is too insignificant to speak of. And this gives us the final estimation of $m_t$ as a scaled mean of the gradients.
*(Image: Stage 5 - final estimate)*
This analysis exposes one huge issue with Adam's update policy. When $t$ is small ($t = 1$), $\mu$ is scaled by $1 - 0.9^1 = 0.1$. When $t$ gets bigger ($t = 20$), $\mu$ is scaled by $1 - 0.9^{20} \approx 0.88$. In other words, the effect of $m_t$ heavily depends on the time step and suffers from unbalanced scaling. This is what motivated the authors of Adam to correct $m_t$ with what we see in the third line of the Adam formulas - a division by $1 - \beta_1^t$:
*(Image: the bias-corrected $\hat{m}_t$)*
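Putting the stages above together, here is my reconstruction of the derivation from the original Adam paper (with $g_i$ the gradient at step $i$):

$$
\begin{aligned}
m_t &= (1 - \beta_1) \sum_{i=1}^{t} \beta_1^{\,t-i}\, g_i &&\text{(unroll the recursion, } m_0 = 0\text{)}\\
\mathbb{E}[m_t] &= \mathbb{E}[g_t]\,(1 - \beta_1)\sum_{i=1}^{t} \beta_1^{\,t-i} + \zeta &&\text{(}\zeta\text{ absorbs non-stationarity)}\\
&= \mathbb{E}[g_t]\,(1 - \beta_1^{\,t}) + \zeta &&\text{(geometric series)}\\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^{\,t}} \;\Rightarrow\; \mathbb{E}[\hat{m}_t] \approx \mathbb{E}[g_t]
\end{aligned}
$$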
We are done with the difficult part. Now, moving from Adam to AdamW is very easy to digest. The "W" in AdamW stands for "Weight Decay". The idea here is to keep the weights from drifting or growing too far, creating a more stable trajectory for our updates. And we implement that idea by subtracting a small fraction of the current weights in every update. This concept is favoured across many different fields. It is similar in spirit to the residual path in convolutional networks or the transformer architecture. And as we move to more advanced optimizers, this "trick" starts to appear everywhere, such as in the Lion and Sophia optimizers.
*(Image: the AdamW update rule)*
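And in case that image doesn't render either, my reading of the decoupled update from the AdamW paper, with $\lambda$ as the weight-decay coefficient, is:

$$
\theta_t = \theta_{t-1} - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\, \theta_{t-1} \right)
$$

The key design choice is that the $\lambda\,\theta_{t-1}$ term is applied directly to the weights, outside the $\sqrt{\hat{v}_t}$ scaling, so the decay strength no longer depends on each parameter's gradient history.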

That's the end of the blog, so feel free to leave now. What follows is more personal and unpolished in every sense.

Cheers if you've stayed until here. Let me express myself a bit and share the purpose of this blog. I'm in a research community with some of the top, sharpest minds from all over the world. I'm on a team of 4 led by an MS/PhD from NUS, and our topic is comparing the AdamW and Muon optimizers in larger-scale post-training stages.

Benchmarking these optimizers is easy in the sense that the optimizer is the deepest component of the training pipeline. Naturally, you have more control over the influencing factors and less noise in the results when benchmarking. Therefore, the results are more reliable, and we can build a line of reasoning without major assumptions.
However, benchmarking optimizers is also very difficult due to the depth of the required knowledge. These are the first building blocks of our training setup, and a small adjustment to the foundation can make the whole building collapse. Therefore, I feel like I need to deeply understand the mathematics behind these optimizers to be able to explain, for example, how a specific behavior arises from adding one more variable to the formula. Moreover, by its nature, we need to measure both model performance and hardware performance to ensure a fair comparison. And working at such a low level does cause me some issues on the engineering side.
Therefore, I want to learn and share this topic: both to prove that it's not too terrible to dive deep into the maths, and to help myself keep up with the team's workload. I really want to contribute to the fullest and be a firm foundation for my future self.

Hope I'll get myself to learn, write, and share about more interesting topics in the near future. But these are what we have gone through today:
- [x] SGD
- [x] Momentum
- [x] AdaGrad
- [x] RMSProp
- [x] Adam / AdamW
- [ ] Muon

References:
- [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980)
- [Understanding Deep Learning Optimizers: Momentum, AdaGrad, RMSProp & Adam](https://towardsdatascience.com/understanding-deep-learning-optimizers-momentum-adagrad-rmsprop-adam-e311e377e9c2/)
- [ML@Purdue](https://www.instagram.com/mlpurdue/)
