<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Eke Chukwudi</title>
    <description>The latest articles on DEV Community by Eke Chukwudi (@eke_chukwudi_90).</description>
    <link>https://dev.to/eke_chukwudi_90</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3962628%2F5d8a167a-7016-4c62-a2ab-4bbb9e220592.jpg</url>
      <title>DEV Community: Eke Chukwudi</title>
      <link>https://dev.to/eke_chukwudi_90</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/eke_chukwudi_90"/>
    <language>en</language>
    <item>
      <title>Intercepting Gradients in PyTorch: Preprocess the Update Before Your Optimizer Sees It</title>
      <dc:creator>Eke Chukwudi</dc:creator>
      <pubDate>Mon, 01 Jun 2026 12:52:00 +0000</pubDate>
      <link>https://dev.to/eke_chukwudi_90/intercepting-gradients-in-pytorch-preprocess-the-update-before-your-optimizer-sees-it-5ckj</link>
      <guid>https://dev.to/eke_chukwudi_90/intercepting-gradients-in-pytorch-preprocess-the-update-before-your-optimizer-sees-it-5ckj</guid>
      <description>&lt;p&gt;Most people tune the optimizer. Almost nobody touches the thing the optimizer actually eats: the gradient. But the gap between &lt;code&gt;loss.backward()&lt;/code&gt; and &lt;code&gt;optimizer.step()&lt;/code&gt; is a real hook point, and you can do useful work there. This is a short, runnable guide to intercepting gradients and transforming them before the update lands.&lt;/p&gt;

&lt;h2&gt;
  
  
  The mental model
&lt;/h2&gt;

&lt;p&gt;A training step is really four moves:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Forward pass: compute predictions and loss&lt;/li&gt;
&lt;li&gt;Backward pass: &lt;code&gt;loss.backward()&lt;/code&gt; fills &lt;code&gt;parameter.grad&lt;/code&gt; for every parameter&lt;/li&gt;
&lt;li&gt;Update: &lt;code&gt;optimizer.step()&lt;/code&gt; reads those &lt;code&gt;.grad&lt;/code&gt; tensors and adjusts the weights&lt;/li&gt;
&lt;li&gt;Reset: &lt;code&gt;optimizer.zero_grad()&lt;/code&gt;
&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The window we care about is between steps 2 and 3. After &lt;code&gt;backward()&lt;/code&gt;, the gradients exist as plain tensors in &lt;code&gt;p.grad&lt;/code&gt;. The optimizer has not touched them yet. That is where we intervene.&lt;/p&gt;

&lt;h2&gt;
  
  
  Step 1: a baseline step
&lt;/h2&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;

&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optim&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Adam&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loss_fn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;MSELoss&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;64&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;loss_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Standard. Now let's get between the last two lines.&lt;/p&gt;

&lt;h2&gt;
  
  
  Step 2: transform the gradients in place
&lt;/h2&gt;

&lt;p&gt;After &lt;code&gt;backward()&lt;/code&gt;, iterate the parameters and modify each &lt;code&gt;.grad&lt;/code&gt; before stepping:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;soft_threshold&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lam&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-4&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# shrink every gradient component toward zero by lam
&lt;/span&gt;    &lt;span class="c1"&gt;# this is the core operation behind wavelet denoising
&lt;/span&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sign&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;clamp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;lam&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;min&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;loss_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;soft_threshold&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That is the entire idea. Anything you can express as a function of a tensor, you can apply to the gradient: clipping, smoothing, denoising, sign-based updates, masking. The optimizer never knows the difference.&lt;/p&gt;

&lt;h2&gt;
  
  
  Step 3: make it reusable and optimizer-agnostic
&lt;/h2&gt;

&lt;p&gt;Editing the training loop by hand gets messy. Wrap it instead, so the transform works with any first-order optimizer without changing your loop:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;GradTransform&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;transform&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;transform&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;group&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;param_groups&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;group&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;params&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
                &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Usage is a drop-in:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;base&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optim&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Adam&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;GradTransform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;base&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;soft_threshold&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# training loop is unchanged
&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;loss_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Swap &lt;code&gt;Adam&lt;/code&gt; for &lt;code&gt;SGD&lt;/code&gt; and it still works, because the wrapper only touches &lt;code&gt;.grad&lt;/code&gt;, which every first-order optimizer reads the same way.&lt;/p&gt;

&lt;h2&gt;
  
  
  A quicker alternative: tensor hooks
&lt;/h2&gt;

&lt;p&gt;If you want the transform to fire automatically during the backward pass instead of after it, register a hook directly on a parameter:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;register_hook&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="nf"&gt;soft_threshold&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The hook runs as the gradient is computed. The wrapper approach is usually easier to reason about because everything happens in one obvious place, but hooks are handy when you want the change to be invisible to the rest of your code.&lt;/p&gt;

&lt;h2&gt;
  
  
  One honest warning
&lt;/h2&gt;

&lt;p&gt;Intercepting gradients is powerful, which means it is also a good way to quietly break training. A transform that helps on a noisy problem can actively hurt on a clean one, where the gradient signal was fine to begin with. So treat any gradient transform as a hypothesis, not a free win: run it against an untouched baseline on your actual task, with the same seeds, and keep it only if the numbers say so. The hook is easy. Earning the improvement is the hard part.&lt;/p&gt;

&lt;p&gt;If you want to see this idea taken all the way, soft-thresholding the gradient is exactly the building block behind WaveGuard, a gradient denoiser I built that swaps the flat threshold for a Haar wavelet transform and gates it so it stays quiet when the gradient is already clean. I benchmarked it honestly, including a task where it actively made things worse. &lt;/p&gt;

&lt;p&gt;Write-up here: &lt;a href="https://medium.com/@chukwudieke61/adam-cant-hear-the-signal-through-the-noise-waveguard-can-ffb1d8963a38" rel="noopener noreferrer"&gt;https://medium.com/@chukwudieke61/adam-cant-hear-the-signal-through-the-noise-waveguard-can-ffb1d8963a38&lt;/a&gt; &lt;/p&gt;

&lt;p&gt;Code here: &lt;a href="https://github.com/Harry-Potter20/wavelet-grad" rel="noopener noreferrer"&gt;https://github.com/Harry-Potter20/wavelet-grad&lt;/a&gt;.&lt;/p&gt;

</description>
      <category>deeplearning</category>
      <category>machinelearning</category>
      <category>python</category>
      <category>tutorial</category>
    </item>
  </channel>
</rss>
