<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Berkan Sesen</title>
    <description>The latest articles on DEV Community by Berkan Sesen (@berkan_sesen).</description>
    <link>https://dev.to/berkan_sesen</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3843317%2F81217b17-b750-4b21-a7f8-d6dbafdcf816.jpg</url>
      <title>DEV Community: Berkan Sesen</title>
      <link>https://dev.to/berkan_sesen</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/berkan_sesen"/>
    <language>en</language>
    <item>
      <title>Value Iteration vs Q-Learning: Dynamic Programming Meets RL</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Mon, 04 May 2026 13:07:14 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/value-iteration-vs-q-learning-dynamic-programming-meets-rl-3b3a</link>
      <guid>https://dev.to/berkan_sesen/value-iteration-vs-q-learning-dynamic-programming-meets-rl-3b3a</guid>
      <description>&lt;p&gt;You have a map of the frozen lake. Every crack in the ice, every slippery patch, every hole is marked. You can sit at your desk and plan the perfect route before stepping foot on the ice. That is value iteration.&lt;/p&gt;

&lt;p&gt;Now imagine you have no map. You lace up your boots and start walking. You slip, you fall into holes, you backtrack. But each time you learn a little more about which moves pay off and which ones do not. That is Q-learning.&lt;/p&gt;

&lt;p&gt;Both approaches solve the same problem (finding the best policy in a Markov Decision Process), but they start from radically different assumptions about what you know. In &lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;our earlier Q-learning post&lt;/a&gt;, we focused purely on the model-free approach. This post puts the two side by side on the same FrozenLake environment, so you can see exactly what a model buys you, and what you give up when you do not have one.&lt;/p&gt;

&lt;p&gt;By the end of this post, you will have implemented both value iteration and Q-learning from scratch, compared their convergence and policies head-to-head, and understood the Bellman equation that underpins them both.&lt;/p&gt;

&lt;h2&gt;
  
  
  Quick Win: Run Both Algorithms
&lt;/h2&gt;

&lt;p&gt;Let's see both algorithms in action. Click the badge to open the interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/value_iteration_vs_q_learning.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Watch value iteration discover the optimal state values sweep by sweep, with "heat" radiating outward from the goal:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8wzvlim0b9hdiq883zfn.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8wzvlim0b9hdiq883zfn.gif" alt="Value iteration evolving state values over sweeps, with values radiating outward from the goal state as the algorithm converges." width="600" height="600"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Here is the complete implementation for both methods:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;gymnasium&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;

&lt;span class="c1"&gt;# ── Value Iteration (model-based) ──────────────────────────
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;value_iteration&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-8&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Compute optimal V* using the Bellman optimality equation.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;nS&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;
    &lt;span class="n"&gt;nA&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;
    &lt;span class="n"&gt;V&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nS&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;delta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nS&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="n"&gt;action_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nA&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nA&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
                &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unwrapped&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
                    &lt;span class="n"&gt;action_values&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;prob&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_s&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
            &lt;span class="n"&gt;best_value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action_values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;delta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;delta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;best_value&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
            &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;best_value&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;delta&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;break&lt;/span&gt;

    &lt;span class="c1"&gt;# Extract greedy policy from V*
&lt;/span&gt;    &lt;span class="n"&gt;policy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nS&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nS&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;action_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nA&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nA&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unwrapped&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
                &lt;span class="n"&gt;action_values&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;prob&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_s&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="n"&gt;policy&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action_values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;policy&lt;/span&gt;

&lt;span class="c1"&gt;# ── Q-Learning (model-free) ────────────────────────────────
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;q_learning&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
               &lt;span class="n"&gt;epsilon_start&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epsilon_end&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decay_rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;7e-3&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Tabular Q-learning with epsilon-greedy exploration.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;nS&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;
    &lt;span class="n"&gt;nA&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;
    &lt;span class="n"&gt;Q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;nS&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nA&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
    &lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;epsilon_start&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ep&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;uniform&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
            &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:])&lt;/span&gt;

            &lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="c1"&gt;# Q-learning update
&lt;/span&gt;            &lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:])&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;next_state&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
                &lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;epsilon_end&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epsilon_start&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;epsilon_end&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;decay_rate&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;ep&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
                &lt;span class="k"&gt;break&lt;/span&gt;

    &lt;span class="n"&gt;policy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;policy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;rewards&lt;/span&gt;

&lt;span class="c1"&gt;# ── Run both on FrozenLake ─────────────────────────────────
&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;FrozenLake-v1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;is_slippery&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;V_star&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vi_policy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;value_iteration&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;Q_star&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ql_policy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ql_rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;q_learning&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;The result:&lt;/strong&gt; Value iteration converges in 184 sweeps and produces a policy that succeeds ~73% of the time. Q-learning, after 10,000 episodes of trial and error, learns a policy that also achieves ~73% success, and agrees with the VI policy on 14 out of 16 states. Both methods find near-identical strategies, but through very different paths.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fr71y8gyfne8bqvrv9lsj.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fr71y8gyfne8bqvrv9lsj.webp" alt="Side-by-side comparison of learned policies, with arrows showing the greedy action in each state and colour showing the state value. Both methods converge to nearly identical policies." width="800" height="414"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;Both algorithms answer the same question: "What is the best action in every state?" But they go about it in fundamentally different ways.&lt;/p&gt;

&lt;h3&gt;
  
  
  Value Iteration: Planning with a Blueprint
&lt;/h3&gt;

&lt;p&gt;Value iteration has access to the environment's full transition model &lt;code&gt;$P(s' \mid s, a)$&lt;/code&gt;. This is the complete blueprint: for every state and action, you know exactly which states you might land in and with what probability.&lt;/p&gt;

&lt;p&gt;The algorithm sweeps through every state, computing the value of the best action using the &lt;strong&gt;Bellman optimality equation&lt;/strong&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DV%28s%29%2520%255Cleftarrow%2520%255Cmax_a%2520%255Csum_%257Bs%27%257D%2520P%28s%27%2520%255Cmid%2520s%252C%2520a%29%2520%255Cleft%255B%2520R%28s%252C%2520a%252C%2520s%27%29%2520%252B%2520%255Cgamma%2520%255C%252C%2520V%28s%27%29%2520%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DV%28s%29%2520%255Cleftarrow%2520%255Cmax_a%2520%255Csum_%257Bs%27%257D%2520P%28s%27%2520%255Cmid%2520s%252C%2520a%29%2520%255Cleft%255B%2520R%28s%252C%2520a%252C%2520s%27%29%2520%252B%2520%255Cgamma%2520%255C%252C%2520V%28s%27%29%2520%255Cright%255D" alt="equation" width="509" height="54"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Each sweep propagates value information one step further from the goal. In the GIF above, you can see this: after sweep 0, only the state next to the goal has any value (0.333). By sweep 5, the values have spread across the grid. By sweep 100, they have stabilised.&lt;/p&gt;
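
&lt;p&gt;To recreate those snapshots yourself, you can record a copy of &lt;code&gt;V&lt;/code&gt; after every sweep. Here is a minimal variation on the &lt;code&gt;value_iteration&lt;/code&gt; function above (the &lt;code&gt;history&lt;/code&gt; list is our addition, and the reshape assumes the default 4x4 map):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;def value_iteration_with_history(env, gamma=0.95, theta=1e-8):
    nS, nA = env.observation_space.n, env.action_space.n
    V, history = np.zeros(nS), []
    while True:
        delta = 0
        for s in range(nS):
            action_values = np.zeros(nA)
            for a in range(nA):
                for prob, next_s, reward, done in env.unwrapped.P[s][a]:
                    action_values[a] += prob * (reward + gamma * V[next_s])
            best_value = np.max(action_values)
            delta = max(delta, abs(best_value - V[s]))
            V[s] = best_value
        history.append(V.copy().reshape(4, 4))  # one 4x4 snapshot per sweep
        if delta &amp;lt; theta:
            break
    return history
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Plotting successive entries of &lt;code&gt;history&lt;/code&gt; reproduces the spreading pattern in the animation.&lt;/p&gt;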

&lt;p&gt;The key line in the code is this inner loop:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unwrapped&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
    &lt;span class="n"&gt;action_values&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;prob&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_s&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This sums over all possible outcomes of taking action &lt;code&gt;$a$&lt;/code&gt; from state &lt;code&gt;$s$&lt;/code&gt;, weighting each by its transition probability. No randomness, no sampling; it is a deterministic computation over the full model.&lt;/p&gt;
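
&lt;p&gt;You can inspect that model directly. Each entry of &lt;code&gt;env.unwrapped.P[s][a]&lt;/code&gt; is a &lt;code&gt;(probability, next_state, reward, done)&lt;/code&gt; tuple, and on the slippery map each action splits its probability three ways. A quick look (the state and action indices here are just an example):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;state, action = 6, 2  # example: press RIGHT from state 6
for prob, next_s, reward, done in env.unwrapped.P[state][action]:
    print(f"p={prob:.3f}  next_state={next_s}  reward={reward}  done={done}")
# Prints three outcomes with p=1/3 each: the intended move plus the
# two perpendicular slips.
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;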

&lt;h3&gt;
  
  
  Q-Learning: Learning by Doing
&lt;/h3&gt;

&lt;p&gt;Q-learning has no access to the transition model. It learns by interacting with the environment, collecting &lt;code&gt;$(s, a, r, s')$&lt;/code&gt; tuples, and updating its Q-table one experience at a time:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:])&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is the &lt;strong&gt;temporal difference (TD) update&lt;/strong&gt;. The term &lt;code&gt;$r + \gamma \max_{a'} Q(s', a')$&lt;/code&gt; is the &lt;strong&gt;TD target&lt;/strong&gt;: what the agent thinks the return should be based on the immediate reward plus the estimated future value. The difference between this target and the current estimate &lt;code&gt;$Q(s, a)$&lt;/code&gt; is the &lt;strong&gt;TD error&lt;/strong&gt;, which drives learning.&lt;/p&gt;
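
&lt;p&gt;Plugging illustrative numbers into the update makes it concrete (the table entries below are made up for the example; &lt;code&gt;lr&lt;/code&gt; and &lt;code&gt;gamma&lt;/code&gt; match the code above):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;q_sa, reward, q_next_max = 0.2, 0.0, 0.5  # hypothetical Q-table entries
lr, gamma = 0.8, 0.95

td_target = reward + gamma * q_next_max  # 0.475
td_error = td_target - q_sa              # 0.275
q_sa += lr * td_error                    # 0.42: nudged toward the target
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;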

&lt;p&gt;Because Q-learning relies on sampled experience rather than exhaustive computation, it needs many more interactions (10,000 episodes vs 184 sweeps). It also needs an exploration strategy (epsilon-greedy) to ensure it visits enough state-action pairs to build an accurate Q-table. If you have already read our &lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;Q-learning tutorial&lt;/a&gt;, these mechanics will be familiar.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why Both Reach the Same Answer
&lt;/h3&gt;

&lt;p&gt;This is not a coincidence. Both algorithms are solving the same &lt;strong&gt;Bellman optimality equation&lt;/strong&gt;. Value iteration solves it through repeated full sweeps over the state space. Q-learning solves it through stochastic approximation: each sampled experience nudges the Q-values toward the true solution, one step at a time.&lt;/p&gt;

&lt;p&gt;Given enough sweeps, value iteration converges exactly. Given enough episodes, Q-learning converges asymptotically (with probability 1, under mild conditions on the learning rate and exploration). On FrozenLake, both methods produce policies that agree on 14 of 16 states and achieve the same ~73% success rate.&lt;/p&gt;
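
&lt;p&gt;One way to check this agreement numerically is to derive Q* from the converged V* with one application of the Bellman equation, then compare greedy policies state by state. A sketch (the &lt;code&gt;q_from_v&lt;/code&gt; helper is ours, not part of the notebook):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;def q_from_v(env, V, gamma=0.95):
    """Recover Q* from V* via one application of the Bellman equation."""
    nS, nA = env.observation_space.n, env.action_space.n
    Q = np.zeros((nS, nA))
    for s in range(nS):
        for a in range(nA):
            for prob, next_s, reward, done in env.unwrapped.P[s][a]:
                Q[s, a] += prob * (reward + gamma * V[next_s])
    return Q

Q_vi = q_from_v(env, V_star)
matches = np.sum(np.argmax(Q_vi, axis=1) == ql_policy)
print(f"Greedy policies agree on {matches}/16 states")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;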

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Why Not 100%? The Stochasticity Tax
&lt;/h3&gt;

&lt;p&gt;Even the optimal policy only succeeds about 73% of the time on slippery FrozenLake. This is not a bug in the algorithm. The environment is genuinely stochastic: each action has only a 1/3 chance of going in the intended direction, with 1/3 probability of sliding in each perpendicular direction. Some episodes are simply doomed: every path to the goal passes near holes, and the ice will occasionally slide you in no matter what the policy does.&lt;/p&gt;

&lt;h3&gt;
  
  
  Convergence: 184 Sweeps vs 10,000 Episodes
&lt;/h3&gt;

&lt;p&gt;Value iteration converges to the exact solution in 184 sweeps:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fo21x3iw67yuvgm5y6bym.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fo21x3iw67yuvgm5y6bym.webp" alt="Value iteration Bellman error drops exponentially, converging to the threshold in 184 iterations." width="800" height="449"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The Bellman error (maximum change in any state value) decreases exponentially. This is because value iteration is a &lt;strong&gt;contraction mapping&lt;/strong&gt;: each sweep brings V closer to the true V* by a factor of at least &lt;code&gt;$\gamma$&lt;/code&gt;. With &lt;code&gt;$\gamma = 0.95$&lt;/code&gt;, the error shrinks by at least 5% per sweep, guaranteeing convergence.&lt;/p&gt;
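
&lt;p&gt;You can turn that contraction factor into a worst-case sweep count. A back-of-the-envelope check (our arithmetic, assuming the first-sweep Bellman error is roughly 0.33):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;gamma, theta = 0.95, 1e-8
error_0 = 0.333  # roughly the largest first-sweep change on FrozenLake
# Worst case: error_k &amp;lt;= gamma**k * error_0; solve for k at the threshold.
k = np.log(theta / error_0) / np.log(gamma)
print(f"Worst-case bound: ~{k:.0f} sweeps")  # ~338
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;The observed 184 sweeps beat this bound because &lt;code&gt;$\gamma$&lt;/code&gt; is only a worst-case contraction factor; the actual per-sweep shrinkage is usually larger.&lt;/p&gt;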

&lt;p&gt;Q-learning, by contrast, follows a noisier path:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ffzja4vqcs01eani7f5af.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ffzja4vqcs01eani7f5af.webp" alt="Q-learning success rate over 10,000 episodes. The training curve is noisy because the agent is still exploring, but the extracted policy reaches VI-level performance." width="800" height="456"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The rolling average hovers around 40-60% during training because the agent is still exploring (epsilon &amp;gt; 0). But the extracted greedy policy, evaluated after training with epsilon = 0, achieves the same ~73% as value iteration.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Model-Based vs Model-Free Tradeoff
&lt;/h3&gt;

&lt;p&gt;This comparison crystallises one of the deepest tradeoffs in reinforcement learning:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fr92ecq0ygikamrniy75h.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fr92ecq0ygikamrniy75h.webp" alt="Head-to-head comparison: VI needs 184 iterations vs Q-learning's 10,000 episodes, both reach 73% success, but VI requires a model while Q-learning does not." width="800" height="282"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Property&lt;/th&gt;
&lt;th&gt;Value Iteration&lt;/th&gt;
&lt;th&gt;Q-Learning&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Needs transition model?&lt;/td&gt;
&lt;td&gt;Yes (env.P)&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Steps to converge&lt;/td&gt;
&lt;td&gt;184 sweeps&lt;/td&gt;
&lt;td&gt;~10,000 episodes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Optimality guarantee&lt;/td&gt;
&lt;td&gt;Exact&lt;/td&gt;
&lt;td&gt;Asymptotic&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Works for unknown environments?&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Memory&lt;/td&gt;
&lt;td&gt;O(|S|)&lt;/td&gt;
&lt;td&gt;O(|S| × |A|)&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;Value iteration is faster and guarantees exact optimality, but it requires something that is rarely available in practice: the full transition model &lt;code&gt;$P(s' \mid s, a)$&lt;/code&gt;. In robotics, game-playing, or any complex real-world task, you almost never have this. That is why model-free methods like Q-learning (and its deep successor, &lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;DQN&lt;/a&gt;) dominate modern RL.&lt;/p&gt;

&lt;h3&gt;
  
  
  Hyperparameter Sensitivity
&lt;/h3&gt;

&lt;p&gt;The original code uses a high learning rate (&lt;code&gt;$\alpha = 0.8$&lt;/code&gt;) and fast epsilon decay (&lt;code&gt;$\text{decay\_rate} = 7 \times 10^{-3}$&lt;/code&gt;). This means Q-learning explores aggressively early on and then commits to exploitation within about 1,000 episodes. The high learning rate works here because FrozenLake has a small, discrete state space. For larger problems, you would need to lower &lt;code&gt;$\alpha$&lt;/code&gt; considerably (our &lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;DQN post&lt;/a&gt; uses 0.001 with a neural network).&lt;/p&gt;
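
&lt;p&gt;You can see how quickly exploration shuts off by evaluating the decay schedule from &lt;code&gt;q_learning&lt;/code&gt; at a few episode counts:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;decay_rate, eps_start, eps_end = 7e-3, 1.0, 0.01
for ep in [0, 250, 500, 1000, 2000]:
    epsilon = eps_end + (eps_start - eps_end) * np.exp(-decay_rate * ep)
    print(f"episode {ep:5d}: epsilon = {epsilon:.3f}")
# By episode 1,000 epsilon is ~0.011, essentially at the 0.01 floor.
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;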

&lt;p&gt;Value iteration, by contrast, has no learning rate. The discount factor &lt;code&gt;$\gamma = 0.95$&lt;/code&gt; is the only tunable parameter, and it has a clear interpretation: how much to value future rewards relative to immediate ones. Higher gamma means the agent plans further ahead but converges more slowly.&lt;/p&gt;

&lt;h3&gt;
  
  
  When to Use Which
&lt;/h3&gt;

&lt;p&gt;&lt;strong&gt;Use value iteration when:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;You have a complete model of the environment (transition probabilities and rewards)&lt;/li&gt;
&lt;li&gt;The state space is small enough to sweep over exhaustively&lt;/li&gt;
&lt;li&gt;You need guaranteed optimality&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;Use Q-learning when:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;You can only interact with the environment through trial and error&lt;/li&gt;
&lt;li&gt;The model is unknown or too complex to specify&lt;/li&gt;
&lt;li&gt;You are willing to trade computation for generality&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;In practice, most interesting problems fall into the Q-learning camp, which is why model-free methods get so much attention. Not all model-free approaches use value functions, though. Methods like &lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;the cross-entropy method&lt;/a&gt; and &lt;a href="https://sesen.ai/blog/simulated-annealing-cartpole" rel="noopener noreferrer"&gt;simulated annealing&lt;/a&gt; search policy space directly without ever estimating state values. But understanding value iteration is essential because it reveals the Bellman equation that underlies all value-based RL. As we saw in our &lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;policy gradient post&lt;/a&gt;, even gradient-based methods ultimately try to maximise the same value function.&lt;/p&gt;

&lt;h2&gt;
  
  
  Deep Dive: The Papers
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Bellman's Foundation
&lt;/h3&gt;

&lt;p&gt;Value iteration traces directly to Richard Bellman's 1957 monograph &lt;em&gt;Dynamic Programming&lt;/em&gt;. Bellman introduced the &lt;strong&gt;principle of optimality&lt;/strong&gt;:&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"An optimal policy has the property that whatever the initial state and initial decision are, the remaining decisions must constitute an optimal policy with regard to the state resulting from the first decision."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;This recursive insight leads to the Bellman optimality equation. For the state-value function:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DV%255E%2A%28s%29%2520%253D%2520%255Cmax_a%2520%255Csum_%257Bs%27%257D%2520P%28s%27%2520%255Cmid%2520s%252C%2520a%29%2520%255Cleft%255B%2520R%28s%252C%2520a%252C%2520s%27%29%2520%252B%2520%255Cgamma%2520%255C%252C%2520V%255E%2A%28s%27%29%2520%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DV%255E%2A%28s%29%2520%253D%2520%255Cmax_a%2520%255Csum_%257Bs%27%257D%2520P%28s%27%2520%255Cmid%2520s%252C%2520a%29%2520%255Cleft%255B%2520R%28s%252C%2520a%252C%2520s%27%29%2520%252B%2520%255Cgamma%2520%255C%252C%2520V%255E%2A%28s%27%29%2520%255Cright%255D" alt="equation" width="523" height="54"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;And for the action-value function (the one Q-learning estimates):&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DQ%255E%2A%28s%252C%2520a%29%2520%253D%2520%255Csum_%257Bs%27%257D%2520P%28s%27%2520%255Cmid%2520s%252C%2520a%29%2520%255Cleft%255B%2520R%28s%252C%2520a%252C%2520s%27%29%2520%252B%2520%255Cgamma%2520%255Cmax_%257Ba%27%257D%2520Q%255E%2A%28s%27%252C%2520a%27%29%2520%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DQ%255E%2A%28s%252C%2520a%29%2520%253D%2520%255Csum_%257Bs%27%257D%2520P%28s%27%2520%255Cmid%2520s%252C%2520a%29%2520%255Cleft%255B%2520R%28s%252C%2520a%252C%2520s%27%29%2520%252B%2520%255Cgamma%2520%255Cmax_%257Ba%27%257D%2520Q%255E%2A%28s%27%252C%2520a%27%29%2520%255Cright%255D" alt="equation" width="583" height="60"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Value iteration simply applies the first equation as an update rule, sweeping over all states until convergence. The convergence is guaranteed because the Bellman operator is a contraction in the sup-norm with coefficient &lt;code&gt;$\gamma &amp;lt; 1$&lt;/code&gt; (proven by Bellman himself and later formalised by Denardo, 1967).&lt;/p&gt;

&lt;h3&gt;
  
  
  Watkins' Q-Learning
&lt;/h3&gt;

&lt;p&gt;Q-learning was introduced by Christopher Watkins in his 1989 PhD thesis at Cambridge, with the convergence proof published in &lt;a href="https://link.springer.com/article/10.1007/BF00992698" rel="noopener noreferrer"&gt;Watkins &amp;amp; Dayan (1992)&lt;/a&gt;. The key insight was that you can learn &lt;code&gt;$Q^*$&lt;/code&gt; directly from experience, without ever knowing the transition model:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DQ%28s%252C%2520a%29%2520%255Cleftarrow%2520Q%28s%252C%2520a%29%2520%252B%2520%255Calpha%2520%255Cleft%255B%2520r%2520%252B%2520%255Cgamma%2520%255Cmax_%257Ba%27%257D%2520Q%28s%27%252C%2520a%27%29%2520-%2520Q%28s%252C%2520a%29%2520%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DQ%28s%252C%2520a%29%2520%255Cleftarrow%2520Q%28s%252C%2520a%29%2520%252B%2520%255Calpha%2520%255Cleft%255B%2520r%2520%252B%2520%255Cgamma%2520%255Cmax_%257Ba%27%257D%2520Q%28s%27%252C%2520a%27%29%2520-%2520Q%28s%252C%2520a%29%2520%255Cright%255D" alt="equation" width="551" height="46"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Watkins &amp;amp; Dayan proved that Q-learning converges to &lt;code&gt;$Q^*$&lt;/code&gt; with probability 1, provided:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;All state-action pairs are visited infinitely often&lt;/li&gt;
&lt;li&gt;The learning rate &lt;code&gt;$\alpha$&lt;/code&gt; satisfies: &lt;code&gt;$\sum \alpha_t = \infty$&lt;/code&gt; and &lt;code&gt;$\sum \alpha_t^2 &amp;lt; \infty$&lt;/code&gt;
&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The first condition is why we need epsilon-greedy exploration. The second is a standard stochastic approximation requirement (Robbins-Monro conditions). In practice, we use a fixed or slowly decaying learning rate and rely on the algorithm converging "well enough" rather than proving formal convergence.&lt;/p&gt;
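
&lt;p&gt;For intuition, one schedule that does satisfy both conditions is &lt;code&gt;$\alpha_t = 1/t$&lt;/code&gt; per state-action pair: the harmonic series diverges while its squares sum to a finite value. A hedged sketch of what that would look like (the &lt;code&gt;visits&lt;/code&gt; counter and &lt;code&gt;rm_update&lt;/code&gt; helper are ours, not part of the code above):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;visits = np.zeros((16, 4))  # per state-action update counts (4x4 FrozenLake)

def rm_update(Q, state, action, reward, next_state, gamma=0.95):
    """One Q-update with a Robbins-Monro-compliant learning rate."""
    visits[state, action] += 1
    alpha = 1.0 / visits[state, action]  # sum diverges, sum of squares converges
    Q[state, action] += alpha * (
        reward + gamma * np.max(Q[next_state, :]) - Q[state, action]
    )
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;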

&lt;h3&gt;
  
  
  The DP-RL Connection
&lt;/h3&gt;

&lt;p&gt;Sutton &amp;amp; Barto's &lt;em&gt;Reinforcement Learning: An Introduction&lt;/em&gt; (2nd ed., 2018) makes the connection explicit in Chapters 4 and 6. Value iteration is presented as a dynamic programming method (Chapter 4), while Q-learning is a temporal difference method (Chapter 6). The book shows that TD methods can be viewed as &lt;strong&gt;sampling-based approximations to DP&lt;/strong&gt;: where DP backs up values using the full distribution over successors, TD methods back up using a single sampled successor.&lt;/p&gt;

&lt;p&gt;This connection runs deep. Every model-free RL algorithm, from &lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;Q-learning&lt;/a&gt; to &lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;DQN&lt;/a&gt; to &lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;policy gradients&lt;/a&gt;, is implicitly solving a Bellman equation. The difference is in how they approximate the expectation: through tabular sweeps (DP), sampled transitions (TD), or complete episode returns (Monte Carlo).&lt;/p&gt;
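
&lt;p&gt;The relationship is easy to demonstrate: an average of sampled one-step targets converges to the full-model backup that DP computes exactly. A small illustration (we sample from &lt;code&gt;env.unwrapped.P&lt;/code&gt; purely for convenience; a model-free agent would gather the same samples by acting):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;s, a, gamma = 0, 2, 0.95
outcomes = env.unwrapped.P[s][a]  # (prob, next_state, reward, done) tuples

# Full DP backup: exact expectation over successor states.
full_backup = sum(p * (r + gamma * V_star[ns]) for p, ns, r, d in outcomes)

# TD-style: average of sampled one-step targets.
probs = [p for p, ns, r, d in outcomes]
idx = np.random.choice(len(outcomes), size=5000, p=probs)
sampled = np.mean([outcomes[i][2] + gamma * V_star[outcomes[i][1]] for i in idx])

print(f"full backup {full_backup:.4f}  vs  sampled mean {sampled:.4f}")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;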

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://press.princeton.edu/books/paperback/9780691146683/dynamic-programming" rel="noopener noreferrer"&gt;Bellman, R. (1957)&lt;/a&gt;. &lt;em&gt;Dynamic Programming&lt;/em&gt;. Princeton University Press. The foundational text.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://link.springer.com/article/10.1007/BF00992698" rel="noopener noreferrer"&gt;Watkins, C.J.C.H. &amp;amp; Dayan, P. (1992)&lt;/a&gt;. "Q-learning". &lt;em&gt;Machine Learning&lt;/em&gt;, 8, 279-292. The convergence proof.&lt;/li&gt;
&lt;li&gt;
&lt;a href="http://incompleteideas.net/book/the-book.html" rel="noopener noreferrer"&gt;Sutton, R.S. &amp;amp; Barto, A.G. (2018)&lt;/a&gt;. &lt;em&gt;Reinforcement Learning: An Introduction&lt;/em&gt;. 2nd edition. Free online. Chapters 4 (DP) and 6 (TD) are directly relevant.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://onlinelibrary.wiley.com/doi/book/10.1002/9780470316887" rel="noopener noreferrer"&gt;Puterman, M.L. (2014)&lt;/a&gt;. &lt;em&gt;Markov Decision Processes&lt;/em&gt;. The definitive theoretical reference.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/value_iteration_vs_q_learning.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Non-slippery mode&lt;/strong&gt;: Set &lt;code&gt;is_slippery=False&lt;/code&gt; and compare. Both methods should now achieve ~100% success. How does this change the convergence speed?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;8x8 grid&lt;/strong&gt;: Try &lt;code&gt;FrozenLake8x8-v1&lt;/code&gt;. Value iteration still works perfectly. How does Q-learning cope with the larger state space?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Learned transition model&lt;/strong&gt;: The original code includes a &lt;code&gt;learn_trans_matrix()&lt;/code&gt; function that estimates &lt;code&gt;$P(s' \mid s, a)$&lt;/code&gt; from random play, then runs VI on the learned model. Try this hybrid approach. How many random episodes do you need before the learned model matches the true one?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Discount factor sensitivity&lt;/strong&gt;: Vary &lt;code&gt;$\gamma$&lt;/code&gt; from 0.5 to 0.99 and plot the success rate for both methods. When does a low gamma hurt?&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Understanding value iteration gives you the theoretical bedrock of RL. Understanding Q-learning gives you the practical tool that works when models are not available. Together, they frame the central tradeoff that drives all of modern reinforcement learning.&lt;/p&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/q-learning-visualizer" rel="noopener noreferrer"&gt;Q-Learning Visualiser&lt;/a&gt; — Watch value iteration and Q-learning converge on grid worlds in the browser&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;Q-Learning on Frozen Lake from Scratch&lt;/a&gt; — Deep dive into tabular Q-learning on the same environment&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;Deep Q-Networks: Experience Replay and Target Networks&lt;/a&gt; — Scaling Q-learning with neural networks&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;Policy Gradients: REINFORCE from Scratch&lt;/a&gt; — The policy-based alternative to value methods&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;Cross-Entropy Method: Evolution-Style RL&lt;/a&gt; — A gradient-free approach to the same control problems&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is the difference between value iteration and Q-learning?
&lt;/h3&gt;

&lt;p&gt;Value iteration is a dynamic programming method that requires a complete model of the environment (transition probabilities and rewards) and sweeps through all states systematically. Q-learning is model-free: it learns from experience without knowing the environment dynamics. Both converge to optimal values, but value iteration is faster when a model is available, while Q-learning works when it is not.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the Bellman equation?
&lt;/h3&gt;

&lt;p&gt;The Bellman equation expresses the value of a state as the immediate reward plus the discounted value of the next state. It is the foundation of both value iteration and Q-learning. Value iteration solves it by iterating the equation across all states until convergence. Q-learning solves it incrementally by updating one state-action pair at a time from experience.&lt;/p&gt;

&lt;h3&gt;
  
  
  When should I use dynamic programming instead of Q-learning?
&lt;/h3&gt;

&lt;p&gt;Use dynamic programming (value iteration, policy iteration) when you have a complete and accurate model of the environment. This is common in games with known rules, inventory management, and operations research. When the model is unknown, too complex, or too large to enumerate, use model-free methods like Q-learning.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the difference between value iteration and policy iteration?
&lt;/h3&gt;

&lt;p&gt;Value iteration updates the value function using the Bellman optimality equation until convergence, then extracts the policy. Policy iteration alternates between evaluating the current policy exactly and improving it greedily. Policy iteration often converges in fewer iterations but each iteration is more expensive. For small state spaces, both work well.&lt;/p&gt;

&lt;h3&gt;
  
  
  Does value iteration always converge?
&lt;/h3&gt;

&lt;p&gt;Yes, for finite MDPs with a discount factor less than 1. The Bellman operator is a contraction mapping, guaranteeing convergence at a geometric rate. The number of iterations needed depends on the discount factor (higher gamma means slower convergence) and the desired precision. In practice, convergence is usually fast for small to moderate state spaces.&lt;/p&gt;

</description>
      <category>reinforcementlearning</category>
      <category>optimisation</category>
      <category>dynamicprogramming</category>
    </item>
    <item>
      <title>Custom Likelihoods in PyMC: One-Inflated Beta Regression for Loan Repayment</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Fri, 01 May 2026 08:47:52 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/custom-likelihoods-in-pymc-one-inflated-beta-regression-for-loan-repayment-2k5k</link>
      <guid>https://dev.to/berkan_sesen/custom-likelihoods-in-pymc-one-inflated-beta-regression-for-loan-repayment-2k5k</guid>
      <description>&lt;p&gt;When a borrower takes out a personal loan, they might repay every penny, default entirely, or land anywhere in between. The interesting variable is the fraction eventually recovered: a number between 0 and 1 for each loan in the portfolio. Plot the distribution across thousands of loans and it looks like a smooth Beta curve with a tall spike bolted on at the right edge — a mass of borrowers who repaid in full.&lt;/p&gt;

&lt;p&gt;That spike is good news for the lender, but a headache for the modeller. Standard Beta regression handles continuous outcomes on (0, 1), but it cannot produce a point mass at the boundary. Logistic regression predicts a binary paid-or-not label, throwing away the partial repayment information. Neither tool fits the data you actually have.&lt;/p&gt;

&lt;p&gt;In our &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;first PyMC post&lt;/a&gt;, we built hierarchical models using built-in distributions. In the &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;second&lt;/a&gt;, we handled non-standard likelihoods with &lt;code&gt;pm.Potential&lt;/code&gt; for right-censored survival data.&lt;/p&gt;

&lt;p&gt;This post takes the final step: writing a piecewise log-likelihood from scratch for a mixture of continuous and discrete components. By the end, you will construct a One-Inflated Beta (OIB) regression in PyMC, hand-code the Beta log-density, and infer how borrower characteristics drive both the probability of full repayment and the expected partial repayment fraction.&lt;/p&gt;
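
&lt;p&gt;Before diving in, here is the piecewise density in miniature: with probability &lt;code&gt;p_one&lt;/code&gt; the loan repays in full (the spike at 1); otherwise the repayment fraction follows a Beta on (0, 1). A plain NumPy/SciPy sketch of the log-likelihood we will hand-code in PyMC (parameter names are illustrative):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from scipy.stats import beta

def oib_logp(y, p_one, a, b):
    """One-Inflated Beta log-density for a single observation y in (0, 1]."""
    if y == 1.0:
        return np.log(p_one)                        # discrete spike at 1
    return np.log1p(-p_one) + beta.logpdf(y, a, b)  # continuous Beta part
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;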

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;Click the badge below to open the full interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/bayesian/one_inflated_beta_regression.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We will generate synthetic loan data for 2,000 borrowers, fit an OIB regression model, and recover the true data-generating parameters.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fuqqebt5jz7pgqovdi171.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fuqqebt5jz7pgqovdi171.gif" alt="Two-panel animation building up as MCMC draws accumulate. Left panel shows the predicted proportion of fully-repaid loans converging to the observed 60.7%. Right panel shows the posterior predictive Beta component gradually matching the observed partial repayment histogram." width="800" height="267"&gt;&lt;/a&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pymc&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pytensor.tensor&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;arviz&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;

&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# --- Generate synthetic loan data ---
&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2000&lt;/span&gt;
&lt;span class="n"&gt;credit_score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loan_to_value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;interest_rate&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;income_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;column_stack&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;credit_score&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loan_to_value&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;interest_rate&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;income_ratio&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;feature_names&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;credit_score&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;loan_to_value&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;interest_rate&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;income_ratio&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="c1"&gt;# True parameters
&lt;/span&gt;&lt;span class="n"&gt;true_psi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;    &lt;span class="c1"&gt;# pi coefficients
&lt;/span&gt;&lt;span class="n"&gt;true_delta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;   &lt;span class="c1"&gt;# theta coefficients
&lt;/span&gt;&lt;span class="n"&gt;true_phi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;5.0&lt;/span&gt;                                          &lt;span class="c1"&gt;# Beta precision
&lt;/span&gt;
&lt;span class="c1"&gt;# Per-loan probability of full repayment (logistic link)
&lt;/span&gt;&lt;span class="n"&gt;logit_pi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;true_psi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;true_psi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;
&lt;span class="n"&gt;pi_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;logit_pi&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="c1"&gt;# Per-loan mean partial repayment (logistic link)
&lt;/span&gt;&lt;span class="n"&gt;logit_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;true_delta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;true_delta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;
&lt;span class="n"&gt;theta_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;logit_theta&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="c1"&gt;# Beta shape parameters from mean-precision
&lt;/span&gt;&lt;span class="n"&gt;alpha_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;theta_true&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;true_phi&lt;/span&gt;
&lt;span class="n"&gt;beta_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;theta_true&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;true_phi&lt;/span&gt;

&lt;span class="c1"&gt;# Sample from the OIB mixture
&lt;/span&gt;&lt;span class="n"&gt;u&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;uniform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;repayment&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;u&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;pi_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;alpha_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta_true&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="n"&gt;n_full&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;repayment&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Fully repaid: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;n_full&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; (&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;n_full&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;)&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Partial repayment: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;n_full&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; (&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;n_full&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;)&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frjpli5x77vrceay312w1.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frjpli5x77vrceay312w1.webp" alt="Histogram of loan repayment fractions showing a tall spike at 1.0 for 1,214 fully repaid loans and a smooth Beta-shaped distribution for 786 partial repayments between 0 and 1." width="800" height="394"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Of 2,000 loans, 1,214 (60.7%) are fully repaid and 786 (39.3%) show partial repayment. The histogram immediately reveals the two populations: a tall spike at 1.0 and a continuous spread below it. No single standard distribution can capture both. Now let's build the OIB model.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Split observations by type
&lt;/span&gt;&lt;span class="n"&gt;full_idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;repayment&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;partial_idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;repayment&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;partial_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;repayment&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;partial_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;oib_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="c1"&gt;# Pi sub-model: probability of full repayment (logistic link)
&lt;/span&gt;    &lt;span class="n"&gt;psi_intercept&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;psi_coeffs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;logit_pi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;psi_intercept&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;psi_coeffs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;pi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Deterministic&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;pi&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;invlogit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logit_pi&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="c1"&gt;# Theta sub-model: mean of partial repayment Beta (logistic link)
&lt;/span&gt;    &lt;span class="n"&gt;delta_intercept&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;delta_coeffs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;logit_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;delta_intercept&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;delta_coeffs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Deterministic&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;theta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;invlogit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logit_theta&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="c1"&gt;# Phi: Beta precision (shared across all loans)
&lt;/span&gt;    &lt;span class="n"&gt;phi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;phi&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Convert mean-precision to standard Beta parameters
&lt;/span&gt;    &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;phi&lt;/span&gt;
    &lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;phi&lt;/span&gt;

    &lt;span class="c1"&gt;# Expected repayment: E[Y] = pi + (1 - pi) * theta
&lt;/span&gt;    &lt;span class="n"&gt;E_f&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Deterministic&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;E_f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pi&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# --- Piecewise log-likelihood via pm.Potential ---
&lt;/span&gt;    &lt;span class="c1"&gt;# Fully repaid loans: log(pi_i)
&lt;/span&gt;    &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;ll_full&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;full_idx&lt;/span&gt;&lt;span class="p"&gt;])))&lt;/span&gt;

    &lt;span class="c1"&gt;# Partial repayments: log(1 - pi_i) + log Beta(y_i | a_i, b_i)
&lt;/span&gt;    &lt;span class="n"&gt;pa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;partial_idx&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;partial_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;beta_logp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gammaln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pa&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;pb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gammaln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pa&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gammaln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
                 &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pa&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;partial_values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
                 &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pb&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;partial_values&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;ll_partial&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;partial_idx&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta_logp&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;oib_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;trace&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;draws&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                      &lt;span class="n"&gt;target_accept&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                      &lt;span class="n"&gt;init&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;jitter+adapt_diag&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                                    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;phi&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8wtdk4i3mnhuap2jacol.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8wtdk4i3mnhuap2jacol.webp" alt="Trace plots for the OIB model showing posterior distributions and MCMC chains for psi_intercept, psi_coeffs, delta_intercept, delta_coeffs, and phi, all exhibiting good mixing and convergence with zero divergences." width="800" height="702"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The trace plots show healthy chains: zero divergences, good mixing across all four chains, and unimodal posteriors centred near the true parameter values. Sampling 1,000 retained draws per chain (after 3,000 tuning steps) across four chains with the Potential-based likelihood took about 6 seconds.&lt;/p&gt;

&lt;p&gt;You just fitted a custom Bayesian mixture model with 11 free parameters. Now let's understand how each piece works.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Two populations, one model
&lt;/h3&gt;

&lt;p&gt;Our data contains two distinct groups. Some borrowers repay their loan in full (repayment fraction = 1.0), and others repay partially (0 &amp;lt; fraction &amp;lt; 1). The OIB model treats this as a mixture: with probability &lt;code&gt;$\pi_i$&lt;/code&gt; the outcome is exactly 1, and with probability &lt;code&gt;$1 - \pi_i$&lt;/code&gt; it follows a Beta distribution.&lt;/p&gt;

&lt;p&gt;Both &lt;code&gt;$\pi_i$&lt;/code&gt; and the Beta mean &lt;code&gt;$\theta_i$&lt;/code&gt; vary across borrowers. A high credit score might increase both the chance of full repayment and the expected partial repayment. The model captures these relationships through separate linear predictors with logistic links, ensuring both quantities stay between 0 and 1.&lt;/p&gt;
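
&lt;p&gt;As a quick illustration (not part of the notebook), here is how the two logistic links translate a single hypothetical borrower's features into probabilities, using the true parameters defined earlier and SciPy's &lt;code&gt;expit&lt;/code&gt;:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from scipy.special import expit  # logistic function: 1 / (1 + exp(-x))

# Hypothetical borrower: credit score 1 SD above average, all else average
x_i = np.array([1.0, 0.0, 0.0, 0.0])

true_psi = np.array([0.5, 0.8, -0.6, -0.4, 0.3])
true_delta = np.array([0.3, 0.5, -0.3, -0.2, 0.2])

pi_i = expit(true_psi[0] + x_i @ true_psi[1:])         # P(full repayment)
theta_i = expit(true_delta[0] + x_i @ true_delta[1:])  # mean partial fraction

print(f"P(full repayment)   = {pi_i:.3f}")    # 0.786
print(f"E[partial fraction] = {theta_i:.3f}") # 0.690
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;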

&lt;h3&gt;
  
  
  The piecewise log-likelihood
&lt;/h3&gt;

&lt;p&gt;The OIB density is a mixture of a point mass and a continuous distribution. For observation &lt;code&gt;$y_i$&lt;/code&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dp%28y_i%29%2520%253D%2520%255Cbegin%257Bcases%257D%2520%255Cpi_i%2520%2526%2520%255Ctext%257Bif%2520%257D%2520y_i%2520%253D%25201%2520%255C%255C%2520%281%2520-%2520%255Cpi_i%29%2520%255Ccdot%2520f_%257B%255Ctext%257BBeta%257D%257D%28y_i%2520%255Cmid%2520%255Calpha_i%252C%2520%255Cbeta_i%29%2520%2526%2520%255Ctext%257Bif%2520%257D%25200%2520%253C%2520y_i%2520%253C%25201%2520%255Cend%257Bcases%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dp%28y_i%29%2520%253D%2520%255Cbegin%257Bcases%257D%2520%255Cpi_i%2520%2526%2520%255Ctext%257Bif%2520%257D%2520y_i%2520%253D%25201%2520%255C%255C%2520%281%2520-%2520%255Cpi_i%29%2520%255Ccdot%2520f_%257B%255Ctext%257BBeta%257D%257D%28y_i%2520%255Cmid%2520%255Calpha_i%252C%2520%255Cbeta_i%29%2520%2526%2520%255Ctext%257Bif%2520%257D%25200%2520%253C%2520y_i%2520%253C%25201%2520%255Cend%257Bcases%257D" alt="equation" width="521" height="75"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Taking logs:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Clog%2520p%28y_i%29%2520%253D%2520%255Cbegin%257Bcases%257D%2520%255Clog%2520%255Cpi_i%2520%2526%2520%255Ctext%257Bif%2520%257D%2520y_i%2520%253D%25201%2520%255C%255C%2520%255Clog%281%2520-%2520%255Cpi_i%29%2520%252B%2520%255Clog%2520f_%257B%255Ctext%257BBeta%257D%257D%28y_i%2520%255Cmid%2520%255Calpha_i%252C%2520%255Cbeta_i%29%2520%2526%2520%255Ctext%257Bif%2520%257D%25200%2520%253C%2520y_i%2520%253C%25201%2520%255Cend%257Bcases%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Clog%2520p%28y_i%29%2520%253D%2520%255Cbegin%257Bcases%257D%2520%255Clog%2520%255Cpi_i%2520%2526%2520%255Ctext%257Bif%2520%257D%2520y_i%2520%253D%25201%2520%255C%255C%2520%255Clog%281%2520-%2520%255Cpi_i%29%2520%252B%2520%255Clog%2520f_%257B%255Ctext%257BBeta%257D%257D%28y_i%2520%255Cmid%2520%255Calpha_i%252C%2520%255Cbeta_i%29%2520%2526%2520%255Ctext%257Bif%2520%257D%25200%2520%253C%2520y_i%2520%253C%25201%2520%255Cend%257Bcases%257D" alt="equation" width="635" height="75"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The addition in the second branch is critical: it corresponds to multiplying the mixing weight &lt;code&gt;$(1 - \pi_i)$&lt;/code&gt; by the Beta density in probability space. A common mistake is to write multiplication of two log quantities (i.e. &lt;code&gt;log(1-pi) * log(Beta(...))&lt;/code&gt;) instead of addition. That would have no probabilistic interpretation.&lt;/p&gt;
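
&lt;p&gt;A quick numeric check (with illustrative values, not from the notebook) makes this concrete: the log of the weighted Beta density equals the sum of the two log terms, while multiplying them yields an unrelated number:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from scipy import stats

pi_i, y_i, a_i, b_i = 0.6, 0.45, 3.0, 2.0  # made-up values

# Probability space: mixing weight times Beta density, then log
direct = np.log((1 - pi_i) * stats.beta.pdf(y_i, a_i, b_i))

# Log space: the terms ADD
correct = np.log(1 - pi_i) + stats.beta.logpdf(y_i, a_i, b_i)

# The tempting-but-wrong version multiplies two log quantities
wrong = np.log(1 - pi_i) * stats.beta.logpdf(y_i, a_i, b_i)

print(np.isclose(direct, correct))  # True
print(direct, wrong)                # -0.626 vs -0.266: not even close
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;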

&lt;p&gt;We implement this by splitting observations into two groups and adding each group's log-likelihood as a separate &lt;code&gt;pm.Potential&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Fully repaid: sum of log(pi_i) over fully-repaid loans
&lt;/span&gt;&lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;ll_full&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;full_idx&lt;/span&gt;&lt;span class="p"&gt;])))&lt;/span&gt;

&lt;span class="c1"&gt;# Partial: sum of log(1 - pi_i) + Beta_logpdf(y_i) over partial loans
&lt;/span&gt;&lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;ll_partial&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;partial_idx&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta_logp&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This pattern should feel familiar from &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;Post 21&lt;/a&gt;, where we used &lt;code&gt;pm.Potential&lt;/code&gt; to handle right-censored observations. The principle is the same: when your likelihood has distinct branches for different observation types, split them into separate Potential terms.&lt;/p&gt;
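
&lt;p&gt;As an aside, the same likelihood can be expressed as a single &lt;code&gt;pm.Potential&lt;/code&gt; using &lt;code&gt;pt.switch&lt;/code&gt; instead of index splitting. The sketch below (an alternative, not the formulation used in this post) would replace the two Potential terms, reusing &lt;code&gt;pi&lt;/code&gt;, &lt;code&gt;a&lt;/code&gt; and &lt;code&gt;b&lt;/code&gt; from the model block:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Inside the model block, in place of ll_full and ll_partial.
# pt.switch evaluates BOTH branches, so observations at exactly 1.0
# must be clipped to keep the unused Beta branch finite.
eps = 1e-6
is_full = pt.eq(repayment, 1.0)
y_safe = pt.clip(repayment, eps, 1 - eps)
beta_lp = (pt.gammaln(a + b) - pt.gammaln(a) - pt.gammaln(b)
           + (a - 1) * pt.log(y_safe) + (b - 1) * pt.log(1 - y_safe))
logp = pt.switch(is_full, pt.log(pi), pt.log(1 - pi) + beta_lp)
pm.Potential('ll_oib', pt.sum(logp))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Splitting by index, as we did above, avoids the clipping workaround entirely, which is one reason to prefer it when the branches are this cleanly separable.&lt;/p&gt;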

&lt;h3&gt;
  
  
  Hand-coding the Beta log-density
&lt;/h3&gt;

&lt;p&gt;Rather than relying on &lt;code&gt;pm.logp(pm.Beta.dist(...), value)&lt;/code&gt;, we compute the Beta log-density directly using the gamma function:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;beta_logp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gammaln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gammaln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gammaln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
             &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This follows from the Beta density formula &lt;code&gt;$f(y \mid \alpha, \beta) = \frac{\Gamma(\alpha + \beta)}{\Gamma(\alpha)\Gamma(\beta)} y^{\alpha-1}(1-y)^{\beta-1}$&lt;/code&gt;. Writing it out explicitly has two advantages: you can see exactly what the sampler is differentiating through, and you sidestep any issues that can arise from using PyMC's internal distribution objects inside Potential expressions.&lt;/p&gt;
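
&lt;p&gt;A quick sanity check (not in the original notebook) is to evaluate the same formula in plain NumPy and compare it against SciPy's reference implementation:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from scipy import stats
from scipy.special import gammaln

def beta_logpdf(y, a, b):
    # Mirrors the pytensor expression above, term for term
    return (gammaln(a + b) - gammaln(a) - gammaln(b)
            + (a - 1) * np.log(y) + (b - 1) * np.log(1 - y))

y = np.array([0.1, 0.45, 0.9])
print(np.allclose(beta_logpdf(y, 3.0, 2.0),
                  stats.beta.logpdf(y, 3.0, 2.0)))  # True
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;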

&lt;h3&gt;
  
  
  The model structure
&lt;/h3&gt;

&lt;p&gt;The model has three sub-components connected by link functions:&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Pi sub-model&lt;/strong&gt; controls which mixture component generates each observation. A logistic link maps the linear predictor &lt;code&gt;$\psi_0 + \psi_1 x_{\text{credit}} + \psi_2 x_{\text{ltv}} + \psi_3 x_{\text{rate}} + \psi_4 x_{\text{income}}$&lt;/code&gt; to a probability. Positive &lt;code&gt;$\psi_1$&lt;/code&gt; means higher credit scores increase the chance of full repayment.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Theta sub-model&lt;/strong&gt; sets the mean of the Beta distribution for partial repayments, also through a logistic link with its own coefficients &lt;code&gt;$\delta_0, \ldots, \delta_4$&lt;/code&gt;. This captures a subtlety that pure classification misses: among borrowers who do not fully repay, some covariates still push the partial fraction higher.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Phi&lt;/strong&gt; is a single shared precision parameter for the Beta component. Higher phi means less variance in partial repayments. It uses a &lt;code&gt;$\text{Gamma}(2, 0.5)$&lt;/code&gt; prior with mean 4, which favours moderate precision values.&lt;/p&gt;
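
&lt;p&gt;For intuition on that prior, here is a small check with SciPy. Note that SciPy parameterizes the Gamma by shape and scale, so PyMC's rate &lt;code&gt;beta=0.5&lt;/code&gt; becomes &lt;code&gt;scale=2&lt;/code&gt;:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;from scipy import stats

phi_prior = stats.gamma(a=2, scale=1 / 0.5)
print(phi_prior.mean())          # 4.0
print(phi_prior.std())           # ~2.83
print(phi_prior.interval(0.94))  # central 94% interval; covers the true phi = 5
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;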

&lt;h3&gt;
  
  
  Checking the fit
&lt;/h3&gt;

&lt;p&gt;Let's compare the estimated coefficients to the true values we used to generate the data.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;summary&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                                        &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;phi&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

&lt;span class="n"&gt;true_vals&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;concatenate&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;true_psi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;true_delta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;true_phi&lt;/span&gt;&lt;span class="p"&gt;]])&lt;/span&gt;
&lt;span class="n"&gt;param_names&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_coeffs[&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;]&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt;
               &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_coeffs[&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;]&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;phi&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;y_pos&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;param_names&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;means&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;mean&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;
&lt;span class="n"&gt;hdi_low&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;hdi_3%&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;
&lt;span class="n"&gt;hdi_high&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;hdi_97%&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;

&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;errorbar&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;means&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pos&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;xerr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;means&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;hdi_low&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hdi_high&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;means&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
            &lt;span class="n"&gt;fmt&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;o&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;steelblue&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;capsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Posterior (94% HDI)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;scatter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;true_vals&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pos&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;marker&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;x&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;crimson&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;80&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
           &lt;span class="n"&gt;zorder&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;True value&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_yticks&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_pos&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_yticklabels&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;param_names&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gray&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Value&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;lower right&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Parameter Recovery: Posterior vs True Values&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkwm5v25kdcuv0vf13111.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkwm5v25kdcuv0vf13111.webp" alt="Forest plot showing posterior means with 94% HDI intervals for all 11 model parameters alongside the true values used for data generation, demonstrating accurate parameter recovery across both the pi and theta sub-models." width="800" height="596"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Every true value falls within its 94% highest density interval. The model correctly identifies that credit score has the strongest positive effect on full repayment (psi_coeffs[0] = 0.85, true: 0.8), while loan-to-value ratio is the strongest negative predictor (psi_coeffs[1] = -0.58, true: -0.6). The precision parameter phi is recovered at 5.47 (true: 5.0), and the effective sample sizes all exceed 2,500.&lt;/p&gt;

&lt;h3&gt;
  
  
  Posterior predictive check
&lt;/h3&gt;

&lt;p&gt;The ultimate test: can the model reproduce the observed data distribution, including the spike at 1.0? Since we used &lt;code&gt;pm.Potential&lt;/code&gt; rather than an observed distribution, we generate predictive samples manually from the posterior:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Extract posterior samples
&lt;/span&gt;&lt;span class="n"&gt;psi_int_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;psi_coeff_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;delta_int_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;delta_coeff_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;phi_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;phi&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="n"&gt;rng&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;default_rng&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;n_draws&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;500&lt;/span&gt;
&lt;span class="n"&gt;ppc_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;n_draws&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_draws&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;lp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;psi_int_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;psi_coeff_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;pi_i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;lp&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;lt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;delta_int_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;delta_coeff_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;theta_i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;lt&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;a_i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b_i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;theta_i&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;phi_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;theta_i&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;phi_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;u_i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rng&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;uniform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;ppc_samples&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;u_i&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;pi_i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;rng&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a_i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b_i&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;hist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;repayment&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;density&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;steelblue&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Observed&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;edgecolor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;white&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;hist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ppc_samples&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;bins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;density&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;coral&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Posterior predictive&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;edgecolor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;white&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Repayment Fraction&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Density&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Posterior Predictive Check&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Falj4mzje5lrej3mucg09.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Falj4mzje5lrej3mucg09.webp" alt="Two-panel posterior predictive check. Left: observed vs predicted proportion of fully-repaid loans (both around 60.7%). Right: observed vs predicted density of partial repayments, showing the Beta component accurately captures the continuous distribution shape." width="800" height="277"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The posterior predictive distribution matches both the spike at 1.0 and the shape of the partial repayment component. This is something neither pure Beta regression nor logistic regression can achieve.&lt;/p&gt;
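&lt;p&gt;A quick numeric companion to the visual check is to compare the one-inflation mass directly. A minimal sketch, reusing the &lt;code&gt;repayment&lt;/code&gt; and &lt;code&gt;ppc_samples&lt;/code&gt; arrays from above:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Compare the discrete mass at 1.0: observed vs posterior predictive
obs_full = (repayment == 1.0).mean()
pred_full = (ppc_samples == 1.0).mean()
print(f"Observed fraction fully repaid:  {obs_full:.3f}")
print(f"Predicted fraction fully repaid: {pred_full:.3f}")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;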

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The mean-precision parameterisation
&lt;/h3&gt;

&lt;p&gt;The standard Beta distribution uses shape parameters &lt;code&gt;$\alpha$&lt;/code&gt; and &lt;code&gt;$\beta$&lt;/code&gt;, but these are difficult to interpret. Learning that a borrower's repayment fraction follows a Beta with &lt;code&gt;$\alpha = 2.8$&lt;/code&gt; and &lt;code&gt;$\beta = 2.1$&lt;/code&gt; tells you almost nothing at a glance. The mean-precision reparameterisation solves this:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu%2520%253D%2520%255Cfrac%257B%255Calpha%257D%257B%255Calpha%2520%252B%2520%255Cbeta%257D%252C%2520%255Cquad%2520%255Cphi%2520%253D%2520%255Calpha%2520%252B%2520%255Cbeta" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu%2520%253D%2520%255Cfrac%257B%255Calpha%257D%257B%255Calpha%2520%252B%2520%255Cbeta%257D%252C%2520%255Cquad%2520%255Cphi%2520%253D%2520%255Calpha%2520%252B%2520%255Cbeta" alt="equation" width="257" height="50"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Now &lt;code&gt;$\mu$&lt;/code&gt; is the mean of the distribution (the expected partial repayment fraction) and &lt;code&gt;$\phi$&lt;/code&gt; is the precision (higher means less spread). The inverse mapping recovers the standard parameters:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha%2520%253D%2520%255Cmu%2520%255Cphi%252C%2520%255Cquad%2520%255Cbeta%2520%253D%2520%281%2520-%2520%255Cmu%29%2520%255Cphi" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha%2520%253D%2520%255Cmu%2520%255Cphi%252C%2520%255Cquad%2520%255Cbeta%2520%253D%2520%281%2520-%2520%255Cmu%29%2520%255Cphi" alt="equation" width="251" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;In our model, &lt;code&gt;$\mu$&lt;/code&gt; is called &lt;code&gt;$\theta$&lt;/code&gt; and depends on covariates through a logistic link. The precision &lt;code&gt;$\phi$&lt;/code&gt; is shared across all observations, which assumes that the variance of partial repayments (given the mean) is the same for all borrowers. This is a simplification; a fully heteroscedastic model would give &lt;code&gt;$\phi$&lt;/code&gt; its own linear predictor.&lt;/p&gt;
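&lt;p&gt;The mapping is small enough to sanity-check by hand. A minimal sketch (the numbers are illustrative, not taken from the fitted model):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;def beta_from_mean_precision(mu, phi):
    """Map (mean, precision) to the standard Beta shape parameters."""
    return mu * phi, (1 - mu) * phi

alpha, beta = beta_from_mean_precision(mu=0.57, phi=5.0)
print(alpha, beta)                    # 2.85 2.15
print(alpha / (alpha + beta))         # recovers mu = 0.57
print(0.57 * (1 - 0.57) / (1 + 5.0))  # variance: mu * (1 - mu) / (1 + phi)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;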

&lt;h3&gt;
  
  
  Why logistic links?
&lt;/h3&gt;

&lt;p&gt;Both &lt;code&gt;$\pi$&lt;/code&gt; (probability of full repayment) and &lt;code&gt;$\theta$&lt;/code&gt; (mean of the Beta) must live in (0, 1). The logistic function &lt;code&gt;$\sigma(x) = 1 / (1 + e^{-x})$&lt;/code&gt; maps any real-valued linear predictor to this interval. This is the same link function used in &lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;logistic regression and Bayesian classification&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;The priors reflect the link: &lt;code&gt;$\text{Normal}(0, 5)$&lt;/code&gt; on the intercepts allows the baseline probability to range widely, while &lt;code&gt;$\text{Normal}(0, 1)$&lt;/code&gt; on the slope coefficients gently regularises each covariate's effect. On the logistic scale, a coefficient of about 0.7 already doubles the odds (since &lt;code&gt;$e^{0.7} \approx 2$&lt;/code&gt;), so a &lt;code&gt;$\text{Normal}(0, 1)$&lt;/code&gt; prior is mildly informative.&lt;/p&gt;
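&lt;p&gt;To make the prior scale concrete, here is a small sketch of the link and the odds multiplier implied by a slope coefficient:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

print(sigmoid(0.0))  # 0.5: a zero linear predictor is a coin flip
print(np.exp(0.7))   # ~2.0: a coefficient of 0.7 doubles the odds
print(np.exp(1.0))   # ~2.7: a coefficient of 1.0 multiplies the odds by e
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;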

&lt;h3&gt;
  
  
  The expected value formula
&lt;/h3&gt;

&lt;p&gt;The overall expected repayment for borrower &lt;code&gt;$i$&lt;/code&gt; combines both components:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmathbb%257BE%257D%255BY_i%255D%2520%253D%2520%255Cpi_i%2520%255Ccdot%25201%2520%252B%2520%281%2520-%2520%255Cpi_i%29%2520%255Ccdot%2520%255Ctheta_i%2520%253D%2520%255Cpi_i%2520%252B%2520%281%2520-%2520%255Cpi_i%29%255Ctheta_i" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmathbb%257BE%257D%255BY_i%255D%2520%253D%2520%255Cpi_i%2520%255Ccdot%25201%2520%252B%2520%281%2520-%2520%255Cpi_i%29%2520%255Ccdot%2520%255Ctheta_i%2520%253D%2520%255Cpi_i%2520%252B%2520%281%2520-%2520%255Cpi_i%29%255Ctheta_i" alt="equation" width="467" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This is the &lt;code&gt;E_f&lt;/code&gt; deterministic in our model. It allows you to rank borrowers by expected repayment even when their risk profiles differ in how they fail: one borrower might have a high chance of full repayment but low partial repayment if they default, while another has a moderate chance of full repayment but high partial recovery.&lt;/p&gt;
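&lt;p&gt;A minimal sketch with two invented borrowers makes that ranking behaviour tangible:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;def expected_repayment(pi, theta):
    """E[Y] = pi * 1 + (1 - pi) * theta."""
    return pi + (1 - pi) * theta

# Likely to repay in full, but low partial recovery on default
print(expected_repayment(pi=0.80, theta=0.20))  # 0.84
# Less likely to repay in full, but high partial recovery
print(expected_repayment(pi=0.55, theta=0.75))  # 0.8875
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Despite the lower full-repayment probability, the second profile carries the higher expected repayment.&lt;/p&gt;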

&lt;h3&gt;
  
  
  Why pm.Potential and not pm.CustomDist?
&lt;/h3&gt;

&lt;p&gt;PyMC offers two ways to implement custom likelihoods. &lt;code&gt;pm.CustomDist&lt;/code&gt; lets you define a distribution from its &lt;code&gt;logp&lt;/code&gt; function, which would look like this for OIB:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;oib_logp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;switch&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eq&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
        &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
        &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;logp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Beta&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is elegant but fragile. The &lt;code&gt;pt.switch&lt;/code&gt; operator evaluates both branches for every observation during automatic differentiation.&lt;/p&gt;

&lt;p&gt;When &lt;code&gt;value = 1.0&lt;/code&gt;, the Beta branch computes &lt;code&gt;pm.logp(Beta, 1.0)&lt;/code&gt;, which returns negative infinity (since the Beta density is zero at boundaries for &lt;code&gt;$\beta &amp;gt; 1$&lt;/code&gt;). Even though the switch selects the other branch, the gradient through the infinite branch corrupts the NUTS sampler. The result: 100% divergence rate.&lt;/p&gt;
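&lt;p&gt;You can reproduce the offending boundary value outside PyMC. A sketch using SciPy's standard Beta density (the shape values are illustrative):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;from scipy.stats import beta

print(beta.logpdf(0.9, a=2.8, b=2.1))  # finite: interior point
print(beta.logpdf(1.0, a=2.8, b=2.1))  # -inf: density is zero at the boundary
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;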

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9f281ukhszqz0zw0hqt2.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9f281ukhszqz0zw0hqt2.webp" alt="Diagram of the One-Inflated Beta model showing covariates feeding into two parallel sub-models: a logistic regression for pi (full repayment probability) and a logistic regression for theta (partial repayment mean), which combine through a piecewise likelihood with a shared precision parameter phi." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The &lt;code&gt;pm.Potential&lt;/code&gt; approach avoids this entirely. By pre-splitting observations into fully-repaid and partial groups, the Beta density is never evaluated at the boundary. This is the same pattern we used for &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;censored data in survival analysis&lt;/a&gt;: separate the observation types, compute each group's log-likelihood independently, and add them as Potential terms.&lt;/p&gt;

&lt;p&gt;The trade-off is that &lt;code&gt;pm.Potential&lt;/code&gt; does not enable &lt;code&gt;pm.sample_posterior_predictive&lt;/code&gt; out of the box (you need to write manual prediction code, as we did). For many production workflows, that is a minor inconvenience compared to the reliability gain.&lt;/p&gt;

&lt;h3&gt;
  
  
  Sampling considerations
&lt;/h3&gt;

&lt;p&gt;Our sampling configuration follows the original code that inspired this tutorial; a consolidated sampling call is sketched after the list:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;3,000 tuning steps&lt;/strong&gt; with 1,000 posterior draws per chain. The long warm-up helps the NUTS sampler adapt its step size to the geometry of the piecewise likelihood.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;4 chains&lt;/strong&gt; for convergence diagnostics. With &lt;code&gt;$\hat{R}$&lt;/code&gt; and effective sample size, four chains provide reliable evidence that the sampler has explored the full posterior.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;target_accept=0.95&lt;/code&gt;&lt;/strong&gt; raises the acceptance threshold from the default 0.8, which reduces divergences in models with sharp likelihood boundaries.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;init='jitter+adapt_diag'&lt;/code&gt;&lt;/strong&gt; initialises each chain near the prior mean with small random perturbations. A practical note from the original code: if covariates have very different scales (e.g., one ranges from 0 to 1 while another ranges from 0 to 200), the default jitter of roughly &lt;code&gt;$\pm 1$&lt;/code&gt; can push initial coefficient values far from reasonable territory. Standardising covariates beforehand, as we did, avoids this.&lt;/li&gt;
&lt;/ul&gt;
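&lt;p&gt;Pulled together, the consolidated call looks roughly like this (a sketch, assuming &lt;code&gt;model&lt;/code&gt; is the OIB model context built earlier):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;with model:
    trace = pm.sample(
        draws=1000,                # posterior draws per chain
        tune=3000,                 # long warm-up for the piecewise likelihood
        chains=4,                  # enough for R-hat and ESS diagnostics
        target_accept=0.95,        # fewer divergences near likelihood boundaries
        init='jitter+adapt_diag',  # works best with standardised covariates
        random_seed=42,
    )
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;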

&lt;h3&gt;
  
  
  When to use something else
&lt;/h3&gt;

&lt;p&gt;The OIB model assumes that exactly-one observations arise from a fundamentally different process than partial observations. If instead you have data with a spike at zero (e.g., insurance claims where most customers file nothing), you want a &lt;strong&gt;zero-inflated&lt;/strong&gt; model. If you have spikes at both boundaries, you need a &lt;strong&gt;zero-and-one-inflated Beta&lt;/strong&gt; (ZOIB).&lt;/p&gt;

&lt;p&gt;For data with no boundary spikes at all, standard &lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;Beta regression&lt;/a&gt; (via MLE or Bayesian inference) is simpler and sufficient. The extra complexity of the OIB mixture is only justified when the data genuinely contains a discrete mass at the boundary.&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;p&gt;The OIB model sits at the intersection of two lines of research: Beta regression for bounded continuous data, and inflated distributions for boundary spikes.&lt;/p&gt;

&lt;h3&gt;
  
  
  Beta regression: Ferrari and Cribari-Neto (2004)
&lt;/h3&gt;

&lt;p&gt;The foundation is the Beta regression model introduced by Silvia Ferrari and Francisco Cribari-Neto in their 2004 paper "Beta Regression for Modelling Rates and Proportions" (Journal of Applied Statistics, 31(7), 799-815). They observed that rates, proportions, and fractions appear everywhere in applied statistics, yet researchers typically transform them (logit, arcsine) and apply linear regression. This is problematic because the transformation distorts the error structure and complicates interpretation.&lt;/p&gt;

&lt;p&gt;Their key insight was to model the response directly as Beta-distributed, using the mean-precision parameterisation we adopted:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Df%28y%253B%2520%255Cmu%252C%2520%255Cphi%29%2520%253D%2520%255Cfrac%257B%255CGamma%28%255Cphi%29%257D%257B%255CGamma%28%255Cmu%255Cphi%29%255C%252C%255CGamma%28%281-%255Cmu%29%255Cphi%29%257D%255C%252C%2520y%255E%257B%255Cmu%255Cphi%2520-%25201%257D%281-y%29%255E%257B%281-%255Cmu%29%255Cphi%2520-%25201%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Df%28y%253B%2520%255Cmu%252C%2520%255Cphi%29%2520%253D%2520%255Cfrac%257B%255CGamma%28%255Cphi%29%257D%257B%255CGamma%28%255Cmu%255Cphi%29%255C%252C%255CGamma%28%281-%255Cmu%29%255Cphi%29%257D%255C%252C%2520y%255E%257B%255Cmu%255Cphi%2520-%25201%257D%281-y%29%255E%257B%281-%255Cmu%29%255Cphi%2520-%25201%257D" alt="equation" width="542" height="59"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$0 &amp;lt; y &amp;lt; 1$&lt;/code&gt;, &lt;code&gt;$0 &amp;lt; \mu &amp;lt; 1$&lt;/code&gt; is the mean, and &lt;code&gt;$\phi &amp;gt; 0$&lt;/code&gt; is the precision. Ferrari and Cribari-Neto showed that this is a natural exponential family model when parameterised through &lt;code&gt;$\mu$&lt;/code&gt;, and proposed a logit link for the mean:&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"The proposed model is useful for situations where the variable of interest is continuous and restricted to the interval (0, 1). [...] A convenient parameterisation of the beta density in terms of the mean and a precision parameter is used."&lt;/p&gt;
&lt;/blockquote&gt;
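&lt;p&gt;The connection between the two parameterisations is easy to verify numerically. A quick sketch checking the mean-precision density above against SciPy's standard one (the values are illustrative):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from scipy.special import gammaln
from scipy.stats import beta

def beta_logpdf_mean_precision(y, mu, phi):
    """Log of the Ferrari/Cribari-Neto density f(y; mu, phi)."""
    return (gammaln(phi) - gammaln(mu * phi) - gammaln((1 - mu) * phi)
            + (mu * phi - 1) * np.log(y)
            + ((1 - mu) * phi - 1) * np.log(1 - y))

print(beta_logpdf_mean_precision(0.3, mu=0.57, phi=5.0))
print(beta.logpdf(0.3, a=0.57 * 5.0, b=0.43 * 5.0))  # identical
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;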

&lt;p&gt;Their framework supports maximum likelihood estimation, but the Bayesian extension (which we use) adds uncertainty quantification and regularisation through priors. The connection to &lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;MLE&lt;/a&gt; is direct: the posterior mode of our model with flat priors equals the MLE of Ferrari and Cribari-Neto's model.&lt;/p&gt;

&lt;h3&gt;
  
  
  Inflated models: Ospina and Ferrari (2010)
&lt;/h3&gt;

&lt;p&gt;The standard Beta has support on the open interval (0, 1), so it cannot assign positive probability to the boundaries 0 or 1. Raydonal Ospina and Silvia Ferrari addressed this in "Inflated Beta Distributions" (Statistical Papers, 51(1), 111-126, 2010). They defined a class of mixed continuous-discrete distributions:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BOIB%257D%28y%2520%255Cmid%2520%255Cpi%252C%2520%255Cmu%252C%2520%255Cphi%29%2520%253D%2520%255Cbegin%257Bcases%257D%2520%255Cpi%2520%2526%2520%255Ctext%257Bif%2520%257D%2520y%2520%253D%25201%2520%255C%255C%2520%281%2520-%2520%255Cpi%29%2520%255Ccdot%2520f_%257B%255Ctext%257BBeta%257D%257D%28y%2520%255Cmid%2520%255Cmu%252C%2520%255Cphi%29%2520%2526%2520%255Ctext%257Bif%2520%257D%25200%2520%253C%2520y%2520%253C%25201%2520%255Cend%257Bcases%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BOIB%257D%28y%2520%255Cmid%2520%255Cpi%252C%2520%255Cmu%252C%2520%255Cphi%29%2520%253D%2520%255Cbegin%257Bcases%257D%2520%255Cpi%2520%2526%2520%255Ctext%257Bif%2520%257D%2520y%2520%253D%25201%2520%255C%255C%2520%281%2520-%2520%255Cpi%29%2520%255Ccdot%2520f_%257B%255Ctext%257BBeta%257D%257D%28y%2520%255Cmid%2520%255Cmu%252C%2520%255Cphi%29%2520%2526%2520%255Ctext%257Bif%2520%257D%25200%2520%253C%2520y%2520%253C%25201%2520%255Cend%257Bcases%257D" alt="equation" width="599" height="75"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This is exactly the piecewise density we implemented with &lt;code&gt;pm.Potential&lt;/code&gt;. The parameter &lt;code&gt;$\pi$&lt;/code&gt; controls the inflation: the probability of observing the boundary value. Ospina and Ferrari also developed zero-inflated and zero-and-one-inflated variants for different boundary patterns.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"In many practical situations, the variable of interest is continuous in the open standard unit interval but may also assume the extreme values zero and/or one with positive probabilities. [...] We introduce a class of inflated beta distributions."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Their work established the theoretical properties (moments, maximum likelihood estimation, score functions) that underpin our Bayesian implementation.&lt;/p&gt;

&lt;h3&gt;
  
  
  From MLE to MCMC
&lt;/h3&gt;

&lt;p&gt;The original MLE approach estimates &lt;code&gt;$\pi$&lt;/code&gt;, &lt;code&gt;$\mu$&lt;/code&gt;, and &lt;code&gt;$\phi$&lt;/code&gt; by maximising the log-likelihood. The Bayesian version replaces optimisation with &lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC sampling&lt;/a&gt;, yielding full posterior distributions rather than point estimates. This is particularly valuable for the OIB model because the piecewise likelihood creates a posterior geometry that point estimates cannot capture: the uncertainty in &lt;code&gt;$\pi$&lt;/code&gt; and &lt;code&gt;$\theta$&lt;/code&gt; is correlated, and the posterior for &lt;code&gt;$\phi$&lt;/code&gt; is often skewed.&lt;/p&gt;

&lt;p&gt;Where Ferrari and Cribari-Neto derived score functions by hand, we supply the log-density components to PyMC and let the NUTS sampler handle the rest. The automatic differentiation in PyTensor computes gradients through the &lt;code&gt;gammaln&lt;/code&gt; and &lt;code&gt;log&lt;/code&gt; operations, enabling efficient Hamiltonian Monte Carlo.&lt;/p&gt;

&lt;h3&gt;
  
  
  Algorithm summary
&lt;/h3&gt;

&lt;p&gt;The complete OIB regression procedure (a NumPy sketch of the log-likelihood follows the list):&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;For each observation &lt;code&gt;$i$&lt;/code&gt;, compute &lt;code&gt;$\pi_i = \sigma(\psi_0 + \mathbf{x}_i^\top \boldsymbol{\psi})$&lt;/code&gt; (full repayment probability)&lt;/li&gt;
&lt;li&gt;Compute &lt;code&gt;$\theta_i = \sigma(\delta_0 + \mathbf{x}_i^\top \boldsymbol{\delta})$&lt;/code&gt; (partial repayment mean)&lt;/li&gt;
&lt;li&gt;Compute Beta shape parameters: &lt;code&gt;$\alpha_i = \theta_i \phi$&lt;/code&gt;, &lt;code&gt;$\beta_i = (1 - \theta_i) \phi$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;Evaluate the piecewise log-likelihood: &lt;code&gt;$\log \pi_i$&lt;/code&gt; if &lt;code&gt;$y_i = 1$&lt;/code&gt;, else &lt;code&gt;$\log(1 - \pi_i) + \log \text{Beta}(y_i \mid \alpha_i, \beta_i)$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;Sum across all observations and sample the posterior via NUTS&lt;/li&gt;
&lt;/ol&gt;
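&lt;p&gt;For reference, here is a minimal NumPy sketch of the log-likelihood in steps 1-4 for a single set of parameter values. It mirrors the &lt;code&gt;pm.Potential&lt;/code&gt; split, so the Beta density is only ever evaluated on the partial group (&lt;code&gt;X&lt;/code&gt;, &lt;code&gt;y&lt;/code&gt; and the parameter values are assumed given):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from scipy.stats import beta as beta_dist

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def oib_loglik(y, X, psi0, psi, delta0, delta, phi):
    pi = sigmoid(psi0 + X @ psi)            # step 1: full-repayment probability
    theta = sigmoid(delta0 + X @ delta)     # step 2: partial-repayment mean
    a, b = theta * phi, (1 - theta) * phi   # step 3: Beta shape parameters
    full = (y == 1.0)                       # step 4: piecewise log-likelihood
    ll_full = np.log(pi[full]).sum()
    ll_partial = (np.log(1 - pi[~full])
                  + beta_dist.logpdf(y[~full], a[~full], b[~full])).sum()
    return ll_full + ll_partial             # step 5's sum; NUTS stays in PyMC
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;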

&lt;h3&gt;
  
  
  Further reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;The Beta regression paper:&lt;/strong&gt; Ferrari, S. and Cribari-Neto, F. (2004). "Beta Regression for Modelling Rates and Proportions." &lt;em&gt;Journal of Applied Statistics&lt;/em&gt;, 31(7), 799-815.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Inflated distributions:&lt;/strong&gt; Ospina, R. and Ferrari, S. (2010). "Inflated Beta Distributions." &lt;em&gt;Statistical Papers&lt;/em&gt;, 51(1), 111-126.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The PyMC CustomDist guide:&lt;/strong&gt; &lt;a href="https://www.pymc.io/projects/docs/en/latest/api/distributions/custom.html" rel="noopener noreferrer"&gt;PyMC documentation on custom distributions&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Previous in this series:&lt;/strong&gt; Start with &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;Hierarchical Bayesian Regression&lt;/a&gt;, then &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;Bayesian Survival Analysis&lt;/a&gt; for the progression from built-in distributions to custom likelihoods.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/distribution-explorer" rel="noopener noreferrer"&gt;Distribution Explorer&lt;/a&gt; — Visualise the Beta distribution and other families used in this model&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/bayes-theorem-calculator" rel="noopener noreferrer"&gt;Bayes' Theorem Calculator&lt;/a&gt; — Explore Bayesian reasoning interactively&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;Hierarchical Bayesian Regression with PyMC: When Groups Share Strength&lt;/a&gt; — Partial pooling and group-level priors in PyMC&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;Bayesian Survival Analysis with PyMC: Modelling Customer Churn&lt;/a&gt; — Another custom likelihood built in PyMC&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt; — Why we use priors and posteriors instead of point estimates&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Metropolis-Hastings: An Island-Hopping Guide&lt;/a&gt; — The sampling engine behind PyMC&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  When should I use a One-Inflated Beta model instead of logistic regression?
&lt;/h3&gt;

&lt;p&gt;Use OIB when your outcome is a fraction between 0 and 1 with a spike at the boundary value of 1. Logistic regression discards the partial repayment information by collapsing everything into a binary label. OIB preserves both the probability of full repayment and the distribution of partial repayments, giving you richer predictions.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why use pm.Potential instead of pm.CustomDist for the likelihood?
&lt;/h3&gt;

&lt;p&gt;The pm.CustomDist approach evaluates both branches of the piecewise likelihood for every observation during automatic differentiation. When the Beta density is evaluated at the boundary value of 1.0, it returns negative infinity, which corrupts the NUTS sampler gradients and causes 100% divergences. Splitting observations with pm.Potential avoids evaluating the Beta density at the boundary entirely.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the mean-precision parameterisation of the Beta distribution?
&lt;/h3&gt;

&lt;p&gt;Instead of the standard shape parameters alpha and beta, the mean-precision form uses mu (the mean, between 0 and 1) and phi (the precision, controlling spread). This is more interpretable: mu directly tells you the expected partial repayment fraction, while phi tells you how concentrated the distribution is around that mean. The standard parameters are recovered as alpha = mu * phi and beta = (1 - mu) * phi.&lt;/p&gt;

&lt;h3&gt;
  
  
  How do I check whether the OIB model fits my data well?
&lt;/h3&gt;

&lt;p&gt;Generate posterior predictive samples by drawing from the fitted model and comparing the resulting distribution to the observed data. The key check is whether the model reproduces both the spike at 1.0 (the proportion of fully repaid loans) and the shape of the continuous partial repayment distribution. If either component is mismatched, the model needs adjustment.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can this model handle spikes at both 0 and 1?
&lt;/h3&gt;

&lt;p&gt;Yes, but you would need a Zero-and-One-Inflated Beta (ZOIB) model. This adds a third mixture component for the spike at zero, with its own probability parameter. The piecewise likelihood gains a third branch, but the pm.Potential implementation pattern remains the same: split observations into three groups and add each group's log-likelihood separately.&lt;/p&gt;
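&lt;p&gt;In PyMC pseudocode, the three &lt;code&gt;pm.Potential&lt;/code&gt; terms would look roughly like this (a sketch only; &lt;code&gt;p0&lt;/code&gt;, &lt;code&gt;p1&lt;/code&gt;, the Beta shapes &lt;code&gt;a&lt;/code&gt;, &lt;code&gt;b&lt;/code&gt; and the index arrays are hypothetical names assumed to be defined inside the model):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# p0, p1: per-observation zero/one inflation probabilities
# idx_zero, idx_one, idx_mid: observations pre-split by boundary type
# y_partial: the strictly interior repayment values
pm.Potential("ll_zero", pt.sum(pt.log(p0[idx_zero])))
pm.Potential("ll_one",  pt.sum(pt.log(p1[idx_one])))
pm.Potential("ll_mid",  pt.sum(
    pt.log(1 - p0[idx_mid] - p1[idx_mid])
    + pm.logp(pm.Beta.dist(alpha=a[idx_mid], beta=b[idx_mid]), y_partial)
))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;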

</description>
      <category>bayesian</category>
      <category>probabilistic</category>
      <category>pymc</category>
      <category>customlikelihood</category>
    </item>
    <item>
      <title>Bayesian Survival Analysis with PyMC: Modelling Customer Churn</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Wed, 29 Apr 2026 09:53:05 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/bayesian-survival-analysis-with-pymc-modelling-customer-churn-55n4</link>
      <guid>https://dev.to/berkan_sesen/bayesian-survival-analysis-with-pymc-modelling-customer-churn-55n4</guid>
      <description>&lt;p&gt;Every subscription business lives or dies by churn. Whether it is a B2B SaaS platform tracking annual contracts or a consumer app watching monthly renewals, the question is the same: how long will this customer stay? The data seems straightforward. Some subscribers cancelled after a month, others after a year. But a large share of customers are still active. They have not churned yet, and you do not know when, or whether, they will.&lt;/p&gt;

&lt;p&gt;A colleague suggested dropping them from the analysis. That felt wrong, and it is: ignoring active customers biases your model toward shorter lifetimes, because you only learn from the people who already left.&lt;/p&gt;

&lt;p&gt;The problem has a name: &lt;strong&gt;right-censoring&lt;/strong&gt;. An active customer who signed up 8 months ago tells you something valuable: they survived &lt;em&gt;at least&lt;/em&gt; 8 months. You don't know when (or whether) they'll churn, but that lower bound is real information.&lt;/p&gt;

&lt;p&gt;Survival analysis handles censoring properly. In our &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;previous post&lt;/a&gt;, we built hierarchical models in PyMC for grouped regression. This post extends that toolkit with a new ingredient: the ability to learn from incomplete observations.&lt;/p&gt;

&lt;p&gt;By the end, you'll build a Bayesian accelerated failure time (AFT) model in PyMC, handle right-censored data with &lt;code&gt;pm.Potential&lt;/code&gt;, compare Weibull and Log-Logistic distributions, and plot individual survival curves for different customer profiles.&lt;/p&gt;

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;First, let's see the model in action. Click the badge below to open the full interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/bayesian/bayesian_survival_analysis.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We'll generate synthetic churn data for 1,000 customers, fit a Weibull AFT model, and plot survival curves.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fx3rnp8lwbc8rg52yidut.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fx3rnp8lwbc8rg52yidut.gif" alt="Survival curves for three customer profiles building up as MCMC samples accumulate. Early frames show scattered, uncertain curves; later frames converge to smooth, separated survival functions for high-value, average, and at-risk customers." width="800" height="500"&gt;&lt;/a&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pandas&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pymc&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pytensor.tensor&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;arviz&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;

&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Generate synthetic churn data: 1,000 customers observed over 24 months
&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1000&lt;/span&gt;
&lt;span class="n"&gt;monthly_spend&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;30&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;clip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;250&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;support_tickets&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;poisson&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;astype&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Standardise covariates
&lt;/span&gt;&lt;span class="n"&gt;spend_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;monthly_spend&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mi"&gt;30&lt;/span&gt;
&lt;span class="n"&gt;tickets_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;support_tickets&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;

&lt;span class="c1"&gt;# True AFT parameters (Gumbel / log-Weibull parameterisation)
&lt;/span&gt;&lt;span class="n"&gt;true_alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;2.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;  &lt;span class="c1"&gt;# intercept, spend, tickets
&lt;/span&gt;&lt;span class="n"&gt;true_s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.6&lt;/span&gt;

&lt;span class="c1"&gt;# True log-time: Y = eta + s * W, where W ~ Gumbel(0,1)
&lt;/span&gt;&lt;span class="n"&gt;eta_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;true_alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;true_alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;spend_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;true_alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tickets_std&lt;/span&gt;
&lt;span class="n"&gt;log_time_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eta_true&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;true_s&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gumbel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;time_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_time_true&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Administrative censoring at 24 months
&lt;/span&gt;&lt;span class="n"&gt;observation_window&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;24.0&lt;/span&gt;
&lt;span class="n"&gt;observed_time&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;minimum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;time_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;observation_window&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;censored&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;time_true&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;observation_window&lt;/span&gt;  &lt;span class="c1"&gt;# True = still active
&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;observed_time&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Total customers: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Churned: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; (&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;)&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Still active (censored): &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; (&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;)&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Total customers: 1000
Churned: 664 (66%)
Still active (censored): 336 (34%)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Before fitting the Bayesian model, let's look at the empirical survival curve using the Kaplan-Meier estimator. This non-parametric method handles censoring correctly by adjusting the risk set at each event time:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Kaplan-Meier estimator (manual, no extra dependencies)
&lt;/span&gt;&lt;span class="n"&gt;order&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argsort&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;observed_time&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;times_sorted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;order&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;events_sorted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="n"&gt;order&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;astype&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;km_times&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;km_survival&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;n_at_risk&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;event&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;times_sorted&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;events_sorted&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;event&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;km_survival&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;km_survival&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;n_at_risk&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;km_times&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;n_at_risk&lt;/span&gt; &lt;span class="o"&gt;-=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;km_times&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;km_survival&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;where&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;post&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;#2196F3&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lw&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Months since signup&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Survival probability&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Kaplan-Meier Survival Curve&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;25&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.05&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fca6xhp3ao8ssvpctzm0a.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fca6xhp3ao8ssvpctzm0a.webp" alt="Kaplan-Meier survival curve for the synthetic churn data. The curve drops steadily over 24 months with censoring tick marks visible. About 34% of customers survive past the 24-month observation window." width="800" height="451"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Now let's fit the Weibull AFT model. The key insight: if &lt;code&gt;$T \sim \text{Weibull}$&lt;/code&gt;, then &lt;code&gt;$Y = \log T$&lt;/code&gt; follows a Gumbel distribution. So we model log-time with a Gumbel likelihood, which lets us write the linear predictor naturally:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;gumbel_log_sf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Log survival function of the Gumbel distribution.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log1p&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;weibull_aft&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="c1"&gt;# Location coefficients (priors match original code: Normal(0, 2))
&lt;/span&gt;    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Scale parameter (must be positive)
&lt;/span&gt;    &lt;span class="n"&gt;log_s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Deterministic&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_s&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="c1"&gt;# Linear predictor for log-time
&lt;/span&gt;    &lt;span class="n"&gt;eta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;spend_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tickets_std&lt;/span&gt;

    &lt;span class="c1"&gt;# Uncensored customers: standard Gumbel likelihood
&lt;/span&gt;    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Gumbel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                       &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

    &lt;span class="c1"&gt;# Censored customers: survival function via pm.Potential
&lt;/span&gt;    &lt;span class="n"&gt;y_cens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_cens&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="nf"&gt;gumbel_log_sf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="c1"&gt;# Sample the posterior
&lt;/span&gt;    &lt;span class="n"&gt;trace&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cores&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                      &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target_accept&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;You just fit a Bayesian survival model that properly handles censored customers. The &lt;code&gt;alpha&lt;/code&gt; coefficients tell you how each covariate affects time-to-churn: positive means longer survival, negative means faster churn. And unlike a point estimate from &lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;maximum likelihood&lt;/a&gt;, you get full posterior distributions over every parameter.&lt;/p&gt;
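&lt;p&gt;Because you have posterior samples rather than a single estimate, questions like "how sure are we that support tickets hurt retention?" become one-liners. A small sketch, assuming the &lt;code&gt;trace&lt;/code&gt; from above:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

# Flatten (chain, draw, coef) samples into (samples, coef)
alpha_post = trace.posterior['alpha'].values.reshape(-1, 3)

# Posterior probability that more tickets mean faster churn
print("P(alpha_tickets &lt; 0 | data) =", (alpha_post[:, 2] &lt; 0).mean())
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;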

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Right-Censoring: Learning from Incomplete Data
&lt;/h3&gt;

&lt;p&gt;The 336 active customers in our data didn't churn during the 24-month observation window. For each one, we know they survived &lt;em&gt;at least&lt;/em&gt; 24 months, but not how much longer they'll stay. This is &lt;strong&gt;right-censoring&lt;/strong&gt;: the true event time is somewhere to the right of what we observed.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fyjqswh6ctfl4cr46c73q.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fyjqswh6ctfl4cr46c73q.webp" alt="Timeline diagram showing 8 example customers. Five lines end with a red X (churn event) at various times. Three lines extend to the 24-month boundary and end with a green arrow (still active, censored). The observation window is shaded." width="800" height="412"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Standard regression would force you to either drop censored customers (biasing estimates downward) or code them as churning at 24 months (also biased). Survival analysis treats the two types of observation differently in the likelihood.&lt;/p&gt;

&lt;p&gt;For a churned customer at time &lt;code&gt;$t_i$&lt;/code&gt;, the likelihood contribution is the probability density &lt;code&gt;$f(t_i)$&lt;/code&gt;: we observed this exact event time. For a censored customer observed until time &lt;code&gt;$c_i$&lt;/code&gt;, the contribution is the survival probability &lt;code&gt;$S(c_i) = P(T &amp;gt; c_i)$&lt;/code&gt;: all we know is they lasted at least this long.&lt;/p&gt;

&lt;p&gt;The total log-likelihood combines both pieces:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cell%28%255Ctheta%29%2520%253D%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Buncensored%257D%257D%2520%255Clog%2520f%28t_i%2520%255Cmid%2520%255Ctheta%29%2520%255C%253B%252B%255C%253B%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Bcensored%257D%257D%2520%255Clog%2520S%28c_i%2520%255Cmid%2520%255Ctheta%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cell%28%255Ctheta%29%2520%253D%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Buncensored%257D%257D%2520%255Clog%2520f%28t_i%2520%255Cmid%2520%255Ctheta%29%2520%255C%253B%252B%255C%253B%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Bcensored%257D%257D%2520%255Clog%2520S%28c_i%2520%255Cmid%2520%255Ctheta%29" alt="equation" width="550" height="54"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This is exactly how our PyMC model works. The &lt;code&gt;pm.Gumbel&lt;/code&gt; line handles the first sum (uncensored density). The &lt;code&gt;pm.Potential&lt;/code&gt; line handles the second sum (censored survival).&lt;/p&gt;
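&lt;p&gt;To see the same two-part likelihood outside PyMC, here is a minimal NumPy/SciPy sketch for a Weibull with shape &lt;code&gt;k&lt;/code&gt; and scale &lt;code&gt;lam&lt;/code&gt; (illustrative names, not from the model code):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from scipy.stats import weibull_min

def censored_loglik(k, lam, times, censored):
    """log f(t) for observed churns plus log S(c) for censored customers."""
    dist = weibull_min(c=k, scale=lam)
    ll_events = dist.logpdf(times[~censored]).sum()    # exact event times
    ll_censored = dist.logsf(times[censored]).sum()    # lasted at least this long
    return ll_events + ll_censored
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;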

&lt;h3&gt;
  
  
  Why Gumbel? The Weibull-Gumbel Connection
&lt;/h3&gt;

&lt;p&gt;The Weibull distribution is the workhorse of survival analysis because it models flexible hazard rates: increasing, decreasing, or constant over time. But regressing on the Weibull's shape and scale directly is numerically awkward; on the log scale it becomes a location-scale family, which is exactly the form a linear predictor wants.&lt;/p&gt;

&lt;p&gt;Here's the trick. If &lt;code&gt;$T \sim \text{Weibull}(k, \lambda)$&lt;/code&gt;, then &lt;code&gt;$Y = \log T$&lt;/code&gt; follows a Gumbel distribution:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DY%2520%253D%2520%255Clog%2520T%2520%253D%2520%255Cmu%2520%252B%2520%255Csigma%2520%255Ccdot%2520W%252C%2520%255Cquad%2520W%2520%255Csim%2520%255Ctext%257BGumbel%257D%280%252C%25201%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DY%2520%253D%2520%255Clog%2520T%2520%253D%2520%255Cmu%2520%252B%2520%255Csigma%2520%255Ccdot%2520W%252C%2520%255Cquad%2520W%2520%255Csim%2520%255Ctext%257BGumbel%257D%280%252C%25201%29" alt="equation" width="470" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$\mu = \log \lambda$&lt;/code&gt; is the location and &lt;code&gt;$\sigma = 1/k$&lt;/code&gt; is the scale. This is the &lt;strong&gt;accelerated failure time&lt;/strong&gt; (AFT) formulation: covariates shift &lt;code&gt;$\mu$&lt;/code&gt;, effectively accelerating or decelerating time. We write the linear predictor as:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ceta_i%2520%253D%2520%255Calpha_0%2520%252B%2520%255Calpha_1%2520%255Ccdot%2520%255Ctext%257Bspend%257D_i%2520%252B%2520%255Calpha_2%2520%255Ccdot%2520%255Ctext%257Btickets%257D_i" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ceta_i%2520%253D%2520%255Calpha_0%2520%252B%2520%255Calpha_1%2520%255Ccdot%2520%255Ctext%257Bspend%257D_i%2520%252B%2520%255Calpha_2%2520%255Ccdot%2520%255Ctext%257Btickets%257D_i" alt="equation" width="368" height="23"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;A positive &lt;code&gt;$\alpha_1$&lt;/code&gt; means higher spending shifts log-time to the right (longer survival). A negative &lt;code&gt;$\alpha_2$&lt;/code&gt; means more support tickets shift it left (faster churn). The coefficients have a direct interpretation: a one-unit increase in &lt;code&gt;$x_j$&lt;/code&gt; multiplies the median survival time by &lt;code&gt;$\exp(\alpha_j)$&lt;/code&gt;.&lt;/p&gt;
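&lt;p&gt;A hedged sketch turning that interpretation into numbers from the posterior draws (flattening &lt;code&gt;alpha&lt;/code&gt; the same way as the plotting code later in this post):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;alpha_post = trace.posterior['alpha'].values.reshape(-1, 3)

for name, draws in [('spend +1 sigma', alpha_post[:, 1]),
                    ('tickets +1 sigma', alpha_post[:, 2])]:
    mult = np.exp(draws)  # multiplier on median survival time
    lo, hi = np.percentile(mult, [3, 97])
    print(f"{name}: x{np.median(mult):.2f} (94% interval x{lo:.2f} to x{hi:.2f})")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;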

&lt;h3&gt;
  
  
  &lt;code&gt;pm.Potential&lt;/code&gt;: Telling PyMC About Partial Information
&lt;/h3&gt;

&lt;p&gt;In our &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;hierarchical regression post&lt;/a&gt;, every observation contributed a full likelihood term through &lt;code&gt;pm.Normal(..., observed=y)&lt;/code&gt;. Censored observations are different: they don't have a fully observed outcome. They only contribute through the survival function.&lt;/p&gt;

&lt;p&gt;&lt;code&gt;pm.Potential('name', value)&lt;/code&gt; adds &lt;code&gt;value&lt;/code&gt; directly to the model's log-posterior. For censored data, we pass the log-survival probability:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;y_cens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_cens&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="nf"&gt;gumbel_log_sf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Think of it this way. For a churned customer, we say "we observed them leave at time &lt;code&gt;$t$&lt;/code&gt;" (standard likelihood). For an active customer, we say "all we know is they're still here after &lt;code&gt;$c$&lt;/code&gt; months" (survival function).&lt;/p&gt;
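&lt;p&gt;As an aside, recent PyMC releases ship a built-in &lt;code&gt;pm.Censored&lt;/code&gt; wrapper that encodes the same censored likelihood without a manual &lt;code&gt;pm.Potential&lt;/code&gt;. A sketch of how the Gumbel likelihood could be rewritten with it, as an alternative formulation rather than what this post's model uses:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Inside a pm.Model() context, replacing the y_obs / y_cens pair.
# Right-censored rows are capped at their censoring time; others are unbounded.
upper = np.where(censored, log_observed_time, np.inf)

y = pm.Censored('y', pm.Gumbel.dist(mu=eta, beta=s),
                lower=None, upper=upper,
                observed=log_observed_time)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;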

&lt;h3&gt;
  
  
  MCMC Diagnostics
&lt;/h3&gt;

&lt;p&gt;Before trusting the results, verify the sampler converged:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot_trace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsfgg84yj6ev5dueud1cb.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsfgg84yj6ev5dueud1cb.webp" alt="ArviZ trace plots for the Weibull AFT model. Top row: alpha posteriors (intercept near 2.5, spend coefficient near 0.4, tickets coefficient near −0.3) with MCMC traces. Bottom row: scale parameter s centred near 0.63. All four chains mix well with stable, overlapping traces." width="800" height="278"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Check the same three diagnostics we covered in the &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;hierarchical regression post&lt;/a&gt;: chains should look like "hairy caterpillars" (good mixing), R-hat below 1.01 (convergence), and effective sample size above 400 per chain (low autocorrelation).&lt;/p&gt;
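&lt;p&gt;These checks are easy to script against the same summary table; a small sketch:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;summ = az.summary(trace, var_names=['alpha', 's'])
n_chains = trace.posterior.sizes['chain']

print("worst R-hat:", summ['r_hat'].max())
print("smallest bulk ESS per chain:", (summ['ess_bulk'] / n_chains).min())
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;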

&lt;h3&gt;
  
  
  Survival Curves from the Posterior
&lt;/h3&gt;

&lt;p&gt;The payoff of a Bayesian AFT model is individual survival curves with uncertainty bands. For any customer profile, we compute the survival probability at each time point across all posterior samples:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;t_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;36&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;log_t_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t_grid&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Extract posterior samples
&lt;/span&gt;&lt;span class="n"&gt;alpha_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;s_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="n"&gt;profiles&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;High-value (spend +1.5σ, tickets −1σ)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;#2196F3&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Average customer&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;                        &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;#FF9800&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;At-risk (spend −1.5σ, tickets +2σ)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;     &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;#F44336&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;span class="p"&gt;}&lt;/span&gt;

&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tk&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;profiles&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="n"&gt;eta_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha_post&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha_post&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;sp&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha_post&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tk&lt;/span&gt;
    &lt;span class="n"&gt;survival&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eta_post&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t_grid&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eta_post&lt;/span&gt;&lt;span class="p"&gt;)):&lt;/span&gt;
        &lt;span class="n"&gt;z&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_t_grid&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;eta_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;s_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;survival&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;z&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;mean_surv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;survival&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;lower&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;percentile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;survival&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;upper&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;percentile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;survival&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;97&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t_grid&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mean_surv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lw&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fill_between&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t_grid&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lower&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;upper&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.15&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Months since signup&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Survival probability&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Predicted Survival Curves by Customer Profile&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;upper right&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;fontsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;36&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.05&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ffghc87rg33vegsxfkjrm.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ffghc87rg33vegsxfkjrm.webp" alt="Survival curves for three customer profiles with 94% HDI bands. The high-value customer (blue) stays above 85% survival at 24 months. The average customer (orange) crosses 50% around 13 months. The at-risk customer (red) drops below 20% by month 10." width="800" height="545"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Each curve shows the model's predicted probability that a customer with those characteristics survives beyond a given time. The high-value customer has a much flatter curve: their predicted median lifetime exceeds 36 months. The at-risk customer (low spend, many support tickets) has a steep drop-off with a median around 5 months.&lt;/p&gt;

&lt;p&gt;Notice the uncertainty bands widen at longer times, especially for the at-risk profile. Fewer customers with those characteristics survive that long, so the model has less data to constrain the prediction.&lt;/p&gt;
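&lt;p&gt;Incidentally, the inner loop over posterior samples can be replaced with one broadcast expression, which helps once you score many profiles. Same variables as the plotting code above:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# (samples, 1) against (1, times) broadcasts to (samples, times)
z = (log_t_grid[None, :] - eta_post[:, None]) / s_post[:, None]
survival = 1 - np.exp(-np.exp(-z))  # same Gumbel survival function, vectorised

mean_surv = survival.mean(axis=0)
lower, upper = np.percentile(survival, [3, 97], axis=0)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;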

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Covariates in the Scale Too
&lt;/h3&gt;

&lt;p&gt;The model above uses a constant scale parameter &lt;code&gt;$s$&lt;/code&gt; for all customers. The original code I adapted goes further by making the scale covariate-dependent:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Ds_i%2520%253D%2520%255Cexp%255C%21%255Cleft%28%255Crho_0%2520%252B%2520%255Crho_1%2520%255Ccdot%2520x_%257Bi1%257D%2520%252B%2520%255Crho_2%2520%255Ccdot%2520x_%257Bi2%257D%255Cright%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Ds_i%2520%253D%2520%255Cexp%255C%21%255Cleft%28%255Crho_0%2520%252B%2520%255Crho_1%2520%255Ccdot%2520x_%257Bi1%257D%2520%252B%2520%255Crho_2%2520%255Ccdot%2520x_%257Bi2%257D%255Cright%29" alt="equation" width="328" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This means the &lt;em&gt;shape&lt;/em&gt; of the Weibull hazard varies across customers. A customer might have both a longer expected lifetime (larger &lt;code&gt;$\eta$&lt;/code&gt;) and more predictable survival (smaller &lt;code&gt;$s$&lt;/code&gt;). In PyMC:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;weibull_aft_hetero&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="c1"&gt;# Location coefficients
&lt;/span&gt;    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# Scale coefficients (matching original code's rho priors)
&lt;/span&gt;    &lt;span class="n"&gt;rho&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;rho&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;eta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;spend_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tickets_std&lt;/span&gt;
    &lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;rho&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;rho&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;spend_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;rho&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tickets_std&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Gumbel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
                       &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="n"&gt;y_cens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_cens&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="nf"&gt;gumbel_log_sf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;

    &lt;span class="n"&gt;trace_hetero&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cores&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                             &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target_accept&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is faithful to the &lt;code&gt;aft_model_factory_explicit&lt;/code&gt; function in the original code, which uses separate &lt;code&gt;rho_interc&lt;/code&gt;, &lt;code&gt;rho_coeff1&lt;/code&gt;, &lt;code&gt;rho_coeff2&lt;/code&gt; parameters for the Gumbel scale. The &lt;code&gt;exp&lt;/code&gt; link ensures &lt;code&gt;$s_i &amp;gt; 0$&lt;/code&gt; for every customer.&lt;/p&gt;
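&lt;p&gt;Since the Weibull shape is &lt;code&gt;$k_i = 1/s_i$&lt;/code&gt;, the heteroscedastic posterior lets you ask whether a given profile's hazard rises or falls over time. A hedged sketch, assuming &lt;code&gt;trace_hetero&lt;/code&gt; from above:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;rho_post = trace_hetero.posterior['rho'].values.reshape(-1, 3)

def weibull_shape(spend, tickets):
    """Posterior draws of k = 1/s for a standardised customer profile."""
    s_i = np.exp(rho_post[:, 0] + rho_post[:, 1] * spend + rho_post[:, 2] * tickets)
    return 1.0 / s_i

# k &gt; 1 means a hazard that rises over time; k &lt; 1 means it falls
print("average customer, median k:", np.median(weibull_shape(0.0, 0.0)))
print("at-risk customer, median k:", np.median(weibull_shape(-1.5, 2.0)))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;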

&lt;h3&gt;
  
  
  Weibull vs Log-Logistic: Which Tail Shape?
&lt;/h3&gt;

&lt;p&gt;The Weibull model assumes the hazard rate is &lt;em&gt;monotonic&lt;/em&gt;: always increasing, always decreasing, or constant. But some churn patterns are non-monotonic. Churn risk might start low during onboarding, climb as trials expire and customers decide whether to renew, then fall again once the remaining base has settled in. A monotonic hazard cannot express that rise-then-fall shape (we plot both hazard shapes below, after the survival-curve comparison).&lt;/p&gt;

&lt;p&gt;The &lt;strong&gt;Log-Logistic&lt;/strong&gt; AFT model handles this. In log-time, the Log-Logistic corresponds to a Logistic distribution, just as the Weibull corresponds to a Gumbel. The swap is straightforward:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;logistic_log_sf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Log survival function of the Logistic distribution.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softplus&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;loglogistic_aft&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;log_s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Deterministic&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_s&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="n"&gt;eta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;spend_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tickets_std&lt;/span&gt;

    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Logistic&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                         &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="n"&gt;y_cens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_cens&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="nf"&gt;logistic_log_sf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="n"&gt;trace_ll&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cores&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                         &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target_accept&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Feplgs0eoeu46xxyayhj2.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Feplgs0eoeu46xxyayhj2.webp" alt="Two-panel comparison of Weibull (left) and Log-Logistic (right) survival curves for the average customer. The Weibull curve decays smoothly following a stretched exponential. The Log-Logistic curve has a heavier tail, decaying more slowly at longer times." width="800" height="345"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Compare the two models using LOO-CV (leave-one-out cross-validation) with ArviZ:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;weibull_loo&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;loo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ll_loo&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;loo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trace_ll&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;compare&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Weibull&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Log-Logistic&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;trace_ll&lt;/span&gt;&lt;span class="p"&gt;}))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Since our synthetic data was generated from a Weibull distribution, the Weibull model should win. On real data, the comparison often reveals which tail shape better captures your customers' churn dynamics.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Cox Proportional Hazards Alternative
&lt;/h3&gt;

&lt;p&gt;Survival analysis has a dominant semi-parametric approach: the Cox proportional hazards (PH) model. It doesn't assume a distribution for the baseline hazard, only that covariates multiply the hazard by a constant factor. This flexibility made it ubiquitous in clinical trials.&lt;/p&gt;

&lt;p&gt;So why choose a parametric Bayesian AFT model? Three reasons:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Full predictive distributions.&lt;/strong&gt; The Cox model gives hazard ratios, but producing survival curves requires additional estimation of the baseline hazard. Our Bayesian AFT model gives survival curves with uncertainty bands directly from the posterior (sketched just after this list).&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Small samples and heavy censoring.&lt;/strong&gt; With many active customers, the Cox model's partial likelihood can be imprecise. Bayesian priors stabilise estimates, especially for rare covariates. This is the same principle of "borrowing strength" we explored in the &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;hierarchical regression post&lt;/a&gt;.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Natural extension.&lt;/strong&gt; PyMC models compose freely. Adding group structure (churn by subscription tier), time-varying covariates, or custom likelihoods is straightforward. The next post in this series demonstrates exactly this with a one-inflated Beta regression.&lt;/li&gt;
&lt;/ol&gt;
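
&lt;p&gt;To make the first point concrete, here is a minimal sketch of turning posterior draws into survival curves with uncertainty bands. It assumes the Weibull AFT parameterisation used earlier, where log-time is Gumbel with location &lt;code&gt;eta&lt;/code&gt; and scale &lt;code&gt;s&lt;/code&gt;, so S(t) = exp(-exp((log t - eta)/s)); &lt;code&gt;eta_ref&lt;/code&gt; and &lt;code&gt;s_draws&lt;/code&gt; are hypothetical 1-D arrays of flattened posterior draws for one reference customer:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

# A sketch, not the notebook's exact code. `eta_ref` and `s_draws` are
# assumed 1-D arrays of posterior draws (linear predictor and scale) for
# a single reference customer.
t = np.linspace(1.0, 60.0, 200)  # grid of months
log_t = np.log(t)

# Weibull AFT survival: S(t) = exp(-exp((log t - eta) / s)), one curve per draw
surv = np.exp(-np.exp((log_t[None, :] - eta_ref[:, None]) / s_draws[:, None]))

curve = np.median(surv, axis=0)                              # posterior median curve
band_lo, band_hi = np.percentile(surv, [5.5, 94.5], axis=0)  # 89% credible band
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;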

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fu8wt5fam678v77rhkb69.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fu8wt5fam678v77rhkb69.webp" alt="Flow diagram showing the AFT model structure. Covariates (monthly spend, support tickets) feed into two linear predictors: one for the location parameter eta and one for the scale parameter s. These combine into a Gumbel distribution for log-time, which maps to a Weibull distribution for actual survival time. Censored and uncensored paths split at the likelihood." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  When NOT to Use Bayesian AFT
&lt;/h3&gt;

&lt;p&gt;If the proportional hazards assumption holds and your dataset is large (tens of thousands of events), the Cox model is faster and assumption-lighter. If you have time-varying covariates that change during a customer's lifetime (e.g., monthly usage patterns), the standard AFT formulation doesn't handle them naturally; you'd need a piecewise approach or a joint model.&lt;/p&gt;

&lt;p&gt;Computational cost matters too. Our 1,000-customer model samples in a few minutes, but production datasets with millions of rows would require approximations like variational inference or mini-batch MCMC.&lt;/p&gt;
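
&lt;p&gt;As a rough sketch of that escape hatch, PyMC's &lt;code&gt;pm.fit&lt;/code&gt; runs ADVI as a drop-in replacement for NUTS; &lt;code&gt;weibull_model&lt;/code&gt; below is a placeholder name for whichever model context you defined above:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import pymc as pm
import arviz as az

# Hedged sketch: variational inference instead of MCMC for large datasets.
# `weibull_model` is a placeholder for the model context defined earlier.
with weibull_model:
    approx = pm.fit(n=30_000, method='advi', random_seed=42)  # mean-field ADVI
    trace_vi = approx.sample(1000)  # draws from the fitted approximation

# trace_vi is InferenceData, so the usual ArviZ workflow still applies
print(az.summary(trace_vi).head())
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;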

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Cox (1972): Proportional Hazards
&lt;/h3&gt;

&lt;p&gt;The modern era of survival analysis began with David Cox's 1972 paper "Regression Models and Life-Tables." Cox introduced the proportional hazards model:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dh%28t%2520%255Cmid%2520%255Cmathbf%257Bx%257D%29%2520%253D%2520h_0%28t%29%2520%255Cexp%28%255Cboldsymbol%257B%255Cbeta%257D%255E%255Ctop%2520%255Cmathbf%257Bx%257D%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dh%28t%2520%255Cmid%2520%255Cmathbf%257Bx%257D%29%2520%253D%2520h_0%28t%29%2520%255Cexp%28%255Cboldsymbol%257B%255Cbeta%257D%255E%255Ctop%2520%255Cmathbf%257Bx%257D%29" alt="equation" width="267" height="29"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$h_0(t)$&lt;/code&gt; is an unspecified baseline hazard. The genius was leaving &lt;code&gt;$h_0$&lt;/code&gt; unspecified and estimating &lt;code&gt;$\boldsymbol{\beta}$&lt;/code&gt; through the &lt;strong&gt;partial likelihood&lt;/strong&gt;, which depends only on the order of events, not their exact times. This paper has been cited over 65,000 times and remains the most-used method in clinical trials.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"The important practical point is that [the partial likelihood] does not require specification of &lt;code&gt;$h_0(t)$&lt;/code&gt;." (Cox, 1972)&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Our AFT model takes a different path: we specify a distribution (Weibull or Log-Logistic), which enables direct time predictions. This parametric assumption is both a strength (more powerful inference when correct) and a weakness (biased inference when wrong).&lt;/p&gt;

&lt;h3&gt;
  
  
  Buckley and James (1979): Accelerated Failure Time
&lt;/h3&gt;

&lt;p&gt;The AFT framework was formalised by Buckley and James in 1979. Their key insight was that the AFT model has a direct linear regression interpretation:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Clog%2520T_i%2520%253D%2520%255Cmathbf%257Bx%257D_i%255E%255Ctop%2520%255Cboldsymbol%257B%255Calpha%257D%2520%252B%2520%255Csigma%2520%255Cepsilon_i" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Clog%2520T_i%2520%253D%2520%255Cmathbf%257Bx%257D_i%255E%255Ctop%2520%255Cboldsymbol%257B%255Calpha%257D%2520%252B%2520%255Csigma%2520%255Cepsilon_i" alt="equation" width="199" height="28"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$\epsilon_i$&lt;/code&gt; follows a known distribution (Gumbel for Weibull, Logistic for Log-Logistic). The coefficients &lt;code&gt;$\alpha_j$&lt;/code&gt; have a clean meaning: a one-unit increase in &lt;code&gt;$x_j$&lt;/code&gt; multiplies the median survival time by &lt;code&gt;$\exp(\alpha_j)$&lt;/code&gt;. This is why it's called "accelerated failure time": covariates speed up or slow down the passage of time.&lt;/p&gt;
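
&lt;p&gt;That interpretation is easy to read straight off the posterior. A small sketch, assuming &lt;code&gt;beta_spend&lt;/code&gt; is a hypothetical 1-D array of posterior draws for the monthly-spend coefficient:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

# exp(alpha_j) multiplies median survival time per one-unit change in x_j.
# `beta_spend` is an assumed 1-D array of posterior draws for that coefficient.
time_ratio = np.exp(beta_spend)

print(f"median multiplier per unit of spend: {np.median(time_ratio):.2f}")
print("89% credible interval:", np.round(np.percentile(time_ratio, [5.5, 94.5]), 2))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;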

&lt;h3&gt;
  
  
  Wei (1992): AFT as an Alternative
&lt;/h3&gt;

&lt;p&gt;L. J. Wei's 1992 paper "The Accelerated Failure Time Model: A Useful Alternative to the Cox Regression Model in Survival Analysis" made the case for AFT models as a practical complement to Cox PH. Wei showed that AFT models are more robust to omitted covariates and provide more interpretable effect sizes.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"When the acceleration factor is constant over time, the AFT model provides a simple and clinically meaningful summary of the survival experience." (Wei, 1992)&lt;/p&gt;
&lt;/blockquote&gt;

&lt;h3&gt;
  
  
  Handling Censoring in PyMC
&lt;/h3&gt;

&lt;p&gt;The &lt;code&gt;pm.Potential&lt;/code&gt; approach for censored data follows directly from the likelihood factorisation. For a dataset with observed and censored outcomes:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DL%28%255Ctheta%29%2520%253D%2520%255Cprod_%257Bi%2520%255Cin%2520%255Ctext%257Buncensored%257D%257D%2520f%28t_i%2520%255Cmid%2520%255Ctheta%29%2520%255C%253B%255Cprod_%257Bi%2520%255Cin%2520%255Ctext%257Bcensored%257D%257D%2520S%28c_i%2520%255Cmid%2520%255Ctheta%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DL%28%255Ctheta%29%2520%253D%2520%255Cprod_%257Bi%2520%255Cin%2520%255Ctext%257Buncensored%257D%257D%2520f%28t_i%2520%255Cmid%2520%255Ctheta%29%2520%255C%253B%255Cprod_%257Bi%2520%255Cin%2520%255Ctext%257Bcensored%257D%257D%2520S%28c_i%2520%255Cmid%2520%255Ctheta%29" alt="equation" width="452" height="54"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Taking logs, the uncensored terms give the standard log-likelihood (handled by &lt;code&gt;pm.Gumbel&lt;/code&gt; or &lt;code&gt;pm.Logistic&lt;/code&gt;). The censored terms give log-survival values (handled by &lt;code&gt;pm.Potential&lt;/code&gt;). This pattern appears throughout the &lt;a href="https://www.pymc.io/projects/examples/en/latest/survival_analysis/weibull_aft.html" rel="noopener noreferrer"&gt;PyMC survival analysis examples&lt;/a&gt; and extends naturally to interval censoring and left censoring by swapping the survival function for the appropriate probability term.&lt;/p&gt;
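
&lt;p&gt;As a hedged illustration of that swap, left censoring replaces the log-survival term with a log-CDF term. For the Logistic log-time model above, F(z) = 1/(1 + exp(-z)), so log F(t) = -softplus(-(log t - eta)/s):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import pytensor.tensor as pt

def logistic_log_cdf(log_t, eta, s):
    # log F(t) for a Logistic log-time model: -softplus(-(log t - eta) / s)
    z = (log_t - eta) / s
    return -pt.softplus(-z)

# Hypothetical left-censored rows would then contribute, inside the model:
# pm.Potential('y_left_cens', logistic_log_cdf(log_bound[left], eta[left], s))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;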

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;The proportional hazards model:&lt;/strong&gt; Cox, D. R. (1972). "Regression Models and Life-Tables." &lt;em&gt;Journal of the Royal Statistical Society: Series B&lt;/em&gt;, 34(2), 187-220.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The AFT framework:&lt;/strong&gt; Buckley, J. &amp;amp; James, I. (1979). "Linear Regression with Censored Data." &lt;em&gt;Biometrika&lt;/em&gt;, 66(3), 429-436.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;AFT as a Cox alternative:&lt;/strong&gt; Wei, L. J. (1992). "The Accelerated Failure Time Model." &lt;em&gt;Statistics in Medicine&lt;/em&gt;, 11(14-15), 1871-1879.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The standard reference:&lt;/strong&gt; Kalbfleisch, J. D. &amp;amp; Prentice, R. L. (2002). &lt;em&gt;The Statistical Analysis of Failure Time Data&lt;/em&gt;, 2nd ed. Wiley.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;PyMC survival example:&lt;/strong&gt; &lt;a href="https://www.pymc.io/projects/examples/en/latest/survival_analysis/weibull_aft.html" rel="noopener noreferrer"&gt;Weibull AFT notebook&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Previous in this series:&lt;/strong&gt; &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;Hierarchical Bayesian Regression with PyMC&lt;/a&gt;, which introduces PyMC, partial pooling, and ArviZ diagnostics.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Next in this series:&lt;/strong&gt; Custom likelihoods in PyMC, where we build a one-inflated Beta regression for bounded outcome data.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/kaplan-meier-calculator" rel="noopener noreferrer"&gt;Kaplan-Meier Calculator&lt;/a&gt; — Estimate survival curves and compare groups interactively&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/medical-stats-calculator" rel="noopener noreferrer"&gt;Medical Statistics Calculator&lt;/a&gt; — Compute sensitivity, specificity, and other diagnostic metrics&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;Hierarchical Bayesian Regression with PyMC&lt;/a&gt;: The first post in this PyMC series, covering partial pooling and MCMC diagnostics with ArviZ.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Island Hopping: Understanding Metropolis-Hastings&lt;/a&gt;: How the NUTS sampler that powers PyMC explores high-dimensional posteriors.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt;: The conceptual foundation for priors, posteriors, and why Bayesian estimates outperform point estimates.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is right-censoring and why does it matter?
&lt;/h3&gt;

&lt;p&gt;Right-censoring occurs when you know a subject survived at least until a certain time, but not the actual event time. In churn analysis, active customers are right-censored because they have not yet churned. Ignoring them biases your model toward shorter lifetimes, since you only learn from customers who already left. Survival analysis handles censoring properly by using the survival function for these partial observations.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the difference between the Cox model and an AFT model?
&lt;/h3&gt;

&lt;p&gt;The Cox proportional hazards model is semi-parametric: it leaves the baseline hazard unspecified and estimates how covariates multiply the hazard rate. The accelerated failure time (AFT) model is fully parametric: it assumes a specific distribution (such as Weibull) and models how covariates accelerate or decelerate time to event. AFT coefficients have a direct interpretation as multipliers on median survival time, while Cox coefficients are hazard ratios.&lt;/p&gt;

&lt;h3&gt;
  
  
  What does pm.Potential do in PyMC?
&lt;/h3&gt;

&lt;p&gt;&lt;code&gt;pm.Potential&lt;/code&gt; adds an arbitrary log-probability term directly to the model's log-posterior. For censored observations, there is no fully observed outcome to pass to a standard likelihood. Instead, you compute the log-survival probability and add it via &lt;code&gt;pm.Potential&lt;/code&gt;, telling PyMC that these customers survived at least this long without specifying when they will actually churn.&lt;/p&gt;

&lt;h3&gt;
  
  
  How do I choose between Weibull and Log-Logistic distributions?
&lt;/h3&gt;

&lt;p&gt;Use Weibull when you expect the hazard rate to be monotonic, either always increasing, always decreasing, or constant over time. Use Log-Logistic when the hazard may be non-monotonic, such as high initial churn that drops as users engage and then rises again later. You can compare the two formally using LOO-CV (leave-one-out cross-validation) in ArviZ.&lt;/p&gt;

&lt;h3&gt;
  
  
  How many customers do I need for a Bayesian survival model?
&lt;/h3&gt;

&lt;p&gt;Bayesian models can work with surprisingly small datasets because priors regularise the estimates, but a practical minimum is a few hundred observations with at least 50 to 100 uncensored events. With heavy censoring (over 80% still active), the model has less information about event times, so you may need a larger sample or more informative priors to get precise estimates.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can I add time-varying covariates to a Bayesian AFT model?
&lt;/h3&gt;

&lt;p&gt;The standard AFT formulation assumes covariates are fixed at baseline and does not naturally handle features that change during a customer's lifetime, such as monthly usage patterns. For time-varying covariates, you would need a piecewise AFT approach that splits each customer's timeline into intervals, or a joint model that links the longitudinal covariate process with the survival outcome.&lt;/p&gt;

</description>
      <category>bayesian</category>
      <category>probabilistic</category>
      <category>survivalanalysis</category>
      <category>pymc</category>
    </item>
    <item>
      <title>Hierarchical Bayesian Regression with PyMC: When Groups Share Strength</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Sun, 26 Apr 2026 12:43:53 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/hierarchical-bayesian-regression-with-pymc-when-groups-share-strength-2hag</link>
      <guid>https://dev.to/berkan_sesen/hierarchical-bayesian-regression-with-pymc-when-groups-share-strength-2hag</guid>
      <description>&lt;p&gt;A multi-line insurer writes auto, home, commercial property, and a dozen other policy types under one roof. Some lines see thousands of claims a year; others might see 50. Every actuary faces the same dilemma: train a separate pricing model for each line and the small ones are pure noise, or pool everything together and pretend a warehouse fire looks like a fender bender. Either way, you lose.&lt;/p&gt;

&lt;p&gt;Hierarchical Bayesian regression offers a third way. Each group gets its own parameters, but those parameters are drawn from a shared population distribution. Groups with plenty of data stay close to their own estimates. Groups with little data get "pulled" toward the population average, borrowing statistical strength from the larger groups. This effect is called &lt;strong&gt;shrinkage&lt;/strong&gt;, and it's one of the most elegant ideas in statistics.&lt;/p&gt;

&lt;p&gt;By the end of this post, you'll build a hierarchical Bayesian regression model in &lt;a href="https://www.pymc.io/" rel="noopener noreferrer"&gt;PyMC&lt;/a&gt;, compare it against pooled and unpooled alternatives, and see shrinkage in action on synthetic insurance data.&lt;/p&gt;

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;First, let's see the hierarchical model in action. Click the badge below to open the full interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/bayesian/hierarchical_bayesian_regression.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We'll generate synthetic insurance claim data for three policy types with deliberately unbalanced sample sizes, then fit a hierarchical model.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fuj18jcvc1oul1myfok00.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fuj18jcvc1oul1myfok00.gif" alt="The Commercial intercept posterior building up as MCMC samples accumulate. Early frames show a jagged histogram; later frames resolve to a smooth distribution centred near the true intercept value of 9.0." width="800" height="450"&gt;&lt;/a&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pandas&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pymc&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;arviz&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;

&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Three policy types: lots of Auto data, moderate Home, very little Commercial
&lt;/span&gt;&lt;span class="n"&gt;groups&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Auto&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;       &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;500&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;7.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;slope&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;0.30&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Home&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;       &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;300&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;8.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;slope&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;0.50&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Commercial&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;  &lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;9.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;slope&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;0.70&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
&lt;span class="p"&gt;}&lt;/span&gt;

&lt;span class="n"&gt;records&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;groups&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()):&lt;/span&gt;
    &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;  &lt;span class="c1"&gt;# log property value (~$160k median)
&lt;/span&gt;    &lt;span class="n"&gt;noise&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;slope&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;noise&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;j&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]):&lt;/span&gt;
        &lt;span class="n"&gt;records&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;
            &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;policy_type&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;group_idx&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_property_value&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_claim_severity&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
        &lt;span class="p"&gt;})&lt;/span&gt;

&lt;span class="n"&gt;df&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;DataFrame&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;records&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fvdwzplqpuruf55us0di9.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fvdwzplqpuruf55us0di9.webp" alt="Scatter plot of log claim severity vs log property value, coloured by policy type. Auto (blue, n=500) and Home (orange, n=300) have dense clusters while Commercial (green, n=50) is sparse." width="800" height="493"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Each policy type has a different intercept and slope, but Commercial has just 50 data points. Now let's fit the hierarchical model:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;n_types&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;groups&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;group_idx&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;
&lt;span class="n"&gt;x_centered&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_property_value&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;12&lt;/span&gt;  &lt;span class="c1"&gt;# center the predictor
&lt;/span&gt;
&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;hierarchical_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="c1"&gt;# Hyperpriors: the "population" distribution that groups are drawn from
&lt;/span&gt;    &lt;span class="n"&gt;mu_alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;mu_alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;sigma_alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;HalfNormal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma_alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;mu_beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;mu_beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;sigma_beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;HalfNormal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma_beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Group-level parameters, drawn from the population
&lt;/span&gt;    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mu_alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma_alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mu_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Observation noise
&lt;/span&gt;    &lt;span class="n"&gt;sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;HalfNormal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Linear model
&lt;/span&gt;    &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;x_centered&lt;/span&gt;
    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                      &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_claim_severity&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Sample the posterior
&lt;/span&gt;    &lt;span class="n"&gt;hierarchical_trace&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1500&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cores&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target_accept&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Summarise the results
&lt;/span&gt;&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hierarchical_trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;You just estimated three group-specific regression lines (one per policy type) while letting them share statistical strength through a common population distribution. The Commercial group, despite having only 50 claims, gets a stable estimate because it borrows information from Auto and Home.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Three Pooling Strategies
&lt;/h3&gt;

&lt;p&gt;To understand why the hierarchical model is special, let's compare it against the two extreme alternatives.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Complete pooling&lt;/strong&gt; ignores group differences entirely. One intercept, one slope for all 850 data points:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pooled_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;HalfNormal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;x_centered&lt;/span&gt;
    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                      &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_claim_severity&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;pooled_trace&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cores&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;No pooling&lt;/strong&gt; treats each group as completely independent. Three separate intercepts, three separate slopes, with no shared information:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;unpooled_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;HalfNormal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;x_centered&lt;/span&gt;
    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                      &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_claim_severity&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;unpooled_trace&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cores&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fite20um2l4abv9expb61.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fite20um2l4abv9expb61.webp" alt="Three-panel comparison of regression lines. Left: complete pooling (one line through all data, clearly wrong for Commercial). Centre: no pooling (three independent lines, Commercial line is noisy). Right: partial pooling (three lines, Commercial is slightly pulled toward the others)." width="800" height="316"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The comparison reveals the key insight. Complete pooling gives a single line dominated by Auto and Home (which together make up 94% of the data), systematically underestimating Commercial's higher intercept and steeper slope. No pooling gives each group its own line, but Commercial's estimate is noisy because it only has 50 points. Partial pooling (the hierarchical model) sits between the two: each group gets its own line, but the lines are gently pulled toward the population average. Groups with little data get pulled more.&lt;/p&gt;
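
&lt;p&gt;You can see the same story numerically. A quick sketch pulls the Commercial intercept out of each trace; index 2 follows the &lt;code&gt;group_idx&lt;/code&gt; order, and &lt;code&gt;alpha_dim_0&lt;/code&gt; is PyMC's default dimension name when no coords are declared:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Commercial is group index 2 (see the `groups` dict above)
post_h = hierarchical_trace.posterior['alpha'].sel(alpha_dim_0=2)
post_u = unpooled_trace.posterior['alpha'].sel(alpha_dim_0=2)
post_p = pooled_trace.posterior['alpha']  # a single shared intercept

print(f"pooled:       {float(post_p.mean()):.2f}")
print(f"unpooled:     {float(post_u.mean()):.2f}")
print(f"hierarchical: {float(post_h.mean()):.2f}  (true value: 9.0)")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;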

&lt;h3&gt;
  
  
  How Hyperpriors Create Partial Pooling
&lt;/h3&gt;

&lt;p&gt;The magic ingredient is the &lt;strong&gt;hyperpriors&lt;/strong&gt;: &lt;code&gt;mu_alpha&lt;/code&gt;, &lt;code&gt;sigma_alpha&lt;/code&gt;, &lt;code&gt;mu_beta&lt;/code&gt;, &lt;code&gt;sigma_beta&lt;/code&gt;. These define a "population distribution" from which group-level parameters are drawn.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha_j%2520%255Csim%2520%255Cmathcal%257BN%257D%28%255Cmu_%255Calpha%252C%2520%255Csigma_%255Calpha%255E2%29%2520%255Cquad%2520%255Ctext%257Bfor%2520%257D%2520j%2520%253D%25201%252C%2520%255Cldots%252C%2520J" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha_j%2520%255Csim%2520%255Cmathcal%257BN%257D%28%255Cmu_%255Calpha%252C%2520%255Csigma_%255Calpha%255E2%29%2520%255Cquad%2520%255Ctext%257Bfor%2520%257D%2520j%2520%253D%25201%252C%2520%255Cldots%252C%2520J" alt="equation" width="354" height="29"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Think of &lt;code&gt;$\mu_\alpha$&lt;/code&gt; as the average intercept across all policy types, and &lt;code&gt;$\sigma_\alpha$&lt;/code&gt; as how much the types are allowed to differ. If the data supports large differences, &lt;code&gt;$\sigma_\alpha$&lt;/code&gt; will be large and each group behaves almost independently (like no pooling). If the groups are similar, &lt;code&gt;$\sigma_\alpha$&lt;/code&gt; shrinks and the group estimates collapse toward the population mean (like complete pooling).&lt;/p&gt;

&lt;p&gt;The sampler learns &lt;code&gt;$\sigma_\alpha$&lt;/code&gt; from the data itself. You don't have to choose between pooling and no pooling; the model figures out the right amount of sharing automatically.&lt;/p&gt;
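
&lt;p&gt;You can inspect what the sampler decided directly: a small &lt;code&gt;sigma_alpha&lt;/code&gt; posterior means the data favours pooling, a large one means the groups genuinely differ:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Posteriors of the population-level parameters: how much do groups differ?
az.plot_posterior(hierarchical_trace,
                  var_names=['mu_alpha', 'sigma_alpha', 'mu_beta', 'sigma_beta'])
plt.tight_layout()
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;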

&lt;h3&gt;
  
  
  Shrinkage: The Key Insight
&lt;/h3&gt;

&lt;p&gt;Shrinkage is the defining feature of hierarchical models. Compare each group's raw sample mean (what you'd get from no pooling) to its hierarchical posterior mean:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F6di4mv60vkp1unyx0ahx.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F6di4mv60vkp1unyx0ahx.webp" alt="Shrinkage plot showing raw group means (circles) and hierarchical posterior means (triangles) for each policy type's intercept. The horizontal dashed line marks the population mean. Commercial moves the most toward the population mean; Auto barely moves." width="800" height="472"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Commercial's intercept gets pulled the most toward the population mean, because it has the least data and therefore the most uncertainty. Auto barely moves, because 500 data points leave little room for the prior to override the evidence. This is exactly the Bayesian compromise between prior and data that we explored in &lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt;.&lt;/p&gt;
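
&lt;p&gt;A short sketch makes the shrinkage measurable. Because the predictor is centred, each group's raw mean of &lt;code&gt;log_claim_severity&lt;/code&gt; approximates its no-pooling intercept, which we can set against the hierarchical posterior mean:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;raw = df.groupby('policy_type')['log_claim_severity'].mean()
alpha_post = hierarchical_trace.posterior['alpha'].mean(dim=('chain', 'draw')).values
pop_mean = float(hierarchical_trace.posterior['mu_alpha'].mean())

for i, name in enumerate(groups):
    shift = alpha_post[i] - raw[name]  # how far the posterior moved from the raw estimate
    print(f"{name:10s} raw={raw[name]:.2f}  posterior={alpha_post[i]:.2f}  "
          f"shift={shift:+.3f}  (population mean {pop_mean:.2f})")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;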

&lt;h3&gt;
  
  
  MCMC Diagnostics with ArviZ
&lt;/h3&gt;

&lt;p&gt;Before trusting the results, we need to verify the sampler converged. ArviZ provides the standard toolkit:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot_trace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hierarchical_trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ft1tllij7tg3kzv6mxw9b.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ft1tllij7tg3kzv6mxw9b.webp" alt="ArviZ trace plots for the hierarchical model. Top row: alpha (three group posteriors and traces). Middle row: beta (three group posteriors and traces). Bottom row: sigma (shared noise parameter). All chains show stable mixing." width="800" height="570"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Three things to check:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Trace mixing&lt;/strong&gt;: The chains should look like "hairy caterpillars", bouncing randomly around a stable mean. If a chain gets stuck or drifts, something is wrong.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;R-hat&lt;/strong&gt; (the Gelman-Rubin statistic): Should be below 1.01 for every parameter. Values above 1.1 indicate the chains haven't converged to the same distribution.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Effective sample size (ESS)&lt;/strong&gt;: Should be at least 400 per chain. Low ESS means the samples are highly autocorrelated and the posterior estimates are unreliable.
&lt;/li&gt;
&lt;/ol&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;summary&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hierarchical_trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;[[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;mean&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sd&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;hdi_3%&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;hdi_97%&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;r_hat&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;ess_bulk&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;If you've worked through our &lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Metropolis-Hastings&lt;/a&gt; tutorial, you'll recognise the core idea: the sampler explores the posterior by proposing moves and accepting or rejecting them. PyMC uses the NUTS sampler (No U-Turn Sampler), a sophisticated variant of Hamiltonian Monte Carlo that automatically tunes step sizes and trajectory lengths.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Why Not a Normal Likelihood?
&lt;/h3&gt;

&lt;p&gt;The model above uses a Normal likelihood, which assumes claim amounts are symmetric around the mean. In practice, insurance claims are &lt;strong&gt;heavy-tailed&lt;/strong&gt;: most claims are small, but a few are enormous. The original code I adapted for this tutorial used a &lt;a href="https://en.wikipedia.org/wiki/Laplace_distribution" rel="noopener noreferrer"&gt;Laplace likelihood&lt;/a&gt; to handle this:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dy_i%2520%255Csim%2520%255Ctext%257BLaplace%257D%28%255Cmu_i%252C%2520b_i%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dy_i%2520%255Csim%2520%255Ctext%257BLaplace%257D%28%255Cmu_i%252C%2520b_i%29" alt="equation" width="202" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The Laplace distribution has heavier tails than the Normal and is more robust to outliers. In PyMC, swapping the likelihood is a single line change:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Replace:  pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y)
# With:     pm.Laplace('y_obs', mu=mu, b=b, observed=y)
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
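
&lt;p&gt;To make that concrete, here is a minimal, self-contained sketch of a robust Laplace regression. The toy &lt;code&gt;X&lt;/code&gt; and &lt;code&gt;y&lt;/code&gt; are stand-ins, not the claims data, and the priors are illustrative:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
import pymc as pm

# Hypothetical toy data with heavy-tailed noise
rng = np.random.default_rng(0)
X = rng.lognormal(size=(200, 1))
y = 2.0 + 1.5 * np.log(X[:, 0]) + rng.laplace(scale=0.5, size=200)

with pm.Model() as robust_model:
    beta0 = pm.Normal('beta0', mu=0, sigma=2)
    beta1 = pm.Normal('beta1', mu=0, sigma=2)
    b = pm.HalfNormal('b', sigma=1)  # Laplace scale plays sigma's role
    mu = beta0 + beta1 * pm.math.log(X[:, 0])
    y_obs = pm.Laplace('y_obs', mu=mu, b=b, observed=y)
    robust_trace = pm.sample()
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;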



&lt;h3&gt;
  
  
  Modelling the Spread Too: Heteroscedastic Regression
&lt;/h3&gt;

&lt;p&gt;The original code goes further. It models &lt;strong&gt;both&lt;/strong&gt; the location &lt;code&gt;$\mu$&lt;/code&gt; and the scale &lt;code&gt;$b$&lt;/code&gt; of the Laplace distribution as functions of the covariates. This is heteroscedastic regression: the amount of noise varies across observations.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu_i%2520%253D%2520%255Cexp%255C%21%255Cleft%28%255Cbeta_0%255E%257B%28j%29%257D%2520%252B%2520%255Cbeta_1%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi1%257D%2520%252B%2520%255Ccdots%2520%252B%2520%255Cbeta_4%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi4%257D%255Cright%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu_i%2520%253D%2520%255Cexp%255C%21%255Cleft%28%255Cbeta_0%255E%257B%28j%29%257D%2520%252B%2520%255Cbeta_1%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi1%257D%2520%252B%2520%255Ccdots%2520%252B%2520%255Cbeta_4%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi4%257D%255Cright%29" alt="equation" width="490" height="45"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Db_i%2520%253D%2520%255Cexp%255C%21%255Cleft%28%255Cgamma_0%255E%257B%28j%29%257D%2520%252B%2520%255Cgamma_1%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi1%257D%2520%252B%2520%255Ccdots%2520%252B%2520%255Cgamma_4%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi4%257D%255Cright%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Db_i%2520%253D%2520%255Cexp%255C%21%255Cleft%28%255Cgamma_0%255E%257B%28j%29%257D%2520%252B%2520%255Cgamma_1%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi1%257D%2520%252B%2520%255Ccdots%2520%252B%2520%255Cgamma_4%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi4%257D%255Cright%29" alt="equation" width="481" height="45"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The &lt;code&gt;$\exp$&lt;/code&gt; ensures both &lt;code&gt;$\mu$&lt;/code&gt; and &lt;code&gt;$b$&lt;/code&gt; are positive (claim severity can't be negative). Each &lt;code&gt;$\beta$&lt;/code&gt; and &lt;code&gt;$\gamma$&lt;/code&gt; coefficient gets its own hierarchical structure:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;full_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;n_groups&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;  &lt;span class="c1"&gt;# policy types
&lt;/span&gt;
    &lt;span class="c1"&gt;# Hyperpriors for intercept
&lt;/span&gt;    &lt;span class="n"&gt;beta0_mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0_mu&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;beta0_sig&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;InverseGamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0_sig&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Group-level intercepts
&lt;/span&gt;    &lt;span class="n"&gt;beta0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta0_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta0_sig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_groups&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# ... repeat for each coefficient and for gamma (scale) parameters ...
&lt;/span&gt;
    &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta0&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;group_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta1&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;group_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;...)&lt;/span&gt;
    &lt;span class="n"&gt;b&lt;/span&gt;  &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;gamma0&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;group_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma1&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;group_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;...)&lt;/span&gt;

    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Laplace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Notice the &lt;code&gt;pm.InverseGamma&lt;/code&gt; hyperprior for the variance parameters. The InverseGamma is the conjugate prior for Normal variance, making it a natural choice. With &lt;code&gt;alpha=2, beta=5&lt;/code&gt;, it places mass on moderate variance values while allowing large ones.&lt;/p&gt;
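
&lt;p&gt;A quick way to sanity-check that prior is to inspect it with scipy (a side sketch; the model code itself doesn't need scipy):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;from scipy import stats

prior = stats.invgamma(a=2, scale=5)
print(prior.mean())               # 5.0 = beta / (alpha - 1)
print(5 / (2 + 1))                # mode = beta / (alpha + 1), about 1.67
print(prior.ppf([0.05, 0.95]))    # most of the mass sits between these quantiles
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;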

&lt;h3&gt;
  
  
  The Three-Tier Model
&lt;/h3&gt;

&lt;p&gt;The code also contains a three-tier hierarchy. Instead of just grouping by policy type, it nests policy type within region:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Population → Policy Type → (Region × Policy Type)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;At the top level, hyper-hyperpriors define the global population. At the middle level, each policy type gets its own parameters drawn from the population. At the bottom level, each (region, policy type) combination gets parameters drawn from its policy type's distribution. The group-level parameters become 2D arrays with shape &lt;code&gt;(n_regions, n_policy_types)&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Hyper-hyperpriors (population level)
&lt;/span&gt;&lt;span class="n"&gt;beta0_mu_mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0_mu_mu&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;beta0_mu_sig&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;InverseGamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0_mu_sig&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Hyperpriors (policy type level)
&lt;/span&gt;&lt;span class="n"&gt;beta0_mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0_mu&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta0_mu_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta0_mu_sig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;beta0_sig&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;InverseGamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0_sig&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Priors (region × policy type level)
&lt;/span&gt;&lt;span class="n"&gt;beta0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta0_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta0_sig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_regions&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This allows a commercial policy in an urban area to differ from one in a suburban area, while both borrow strength from the overall commercial distribution, which itself borrows from the global population.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkpwi3jds0fdvqi0whtc1.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkpwi3jds0fdvqi0whtc1.webp" alt="Diagram of the two-tier hierarchical structure: population hyperpriors at the top feeding into policy-type parameters (Auto, Home, Commercial) in the middle, which govern the observed data at the bottom. The wider arrow to Commercial indicates more shrinkage due to its smaller sample size." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;
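
&lt;p&gt;One practical detail: in the likelihood, each observation has to pick out its own (region, policy type) intercept from that 2D array. Paired integer indexing does the job, and the same indexing works on the PyMC tensor inside the model. Here is a numpy stand-in (the index arrays are hypothetical):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

# Stand-in for the model's beta0, shape (n_regions, n_types)
beta0 = np.arange(9.0).reshape(3, 3)

# One entry per observation: which region and policy type it belongs to
region_idx = np.array([0, 0, 1, 2, 1])
type_idx = np.array([0, 1, 1, 2, 0])

# Paired fancy indexing returns one intercept per observation
print(beta0[region_idx, type_idx])  # [0. 1. 4. 8. 3.]
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;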

&lt;h3&gt;
  
  
  When Not to Use Hierarchical Models
&lt;/h3&gt;

&lt;p&gt;Hierarchical models aren't always necessary. If every group has plenty of data (thousands of observations), no pooling gives nearly identical results to partial pooling because the data overwhelms the prior. The hierarchy adds complexity and sampling time for little benefit.&lt;/p&gt;

&lt;p&gt;They can also struggle with very few groups. With only 2 groups, the hyperprior variance &lt;code&gt;$\sigma_\alpha$&lt;/code&gt; is estimated from just 2 data points (the two group-level parameters), making it unreliable. Most practitioners suggest hierarchical models shine with 5 or more groups, though the exact threshold depends on within-group sample sizes.&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Lindley and Smith (1972)
&lt;/h3&gt;

&lt;p&gt;The mathematical foundation was laid by Dennis Lindley and Adrian Smith in their 1972 paper "Bayes Estimates for the Linear Model." They formalised the multi-level Normal model:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmathbf%257By%257D%2520%255Cmid%2520%255Cboldsymbol%257B%255Ctheta%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%28A%255Cboldsymbol%257B%255Ctheta%257D%252C%255C%252C%2520C%29%2520%255Cqquad%2520%255Cboldsymbol%257B%255Ctheta%257D%2520%255Cmid%2520%255Cboldsymbol%257B%255Cmu%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%28B%255Cboldsymbol%257B%255Cmu%257D%252C%255C%252C%2520D%29%2520%255Cqquad%2520%255Cboldsymbol%257B%255Cmu%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%28%255Cmathbf%257B0%257D%252C%255C%252C%2520E%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmathbf%257By%257D%2520%255Cmid%2520%255Cboldsymbol%257B%255Ctheta%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%28A%255Cboldsymbol%257B%255Ctheta%257D%252C%255C%252C%2520C%29%2520%255Cqquad%2520%255Cboldsymbol%257B%255Ctheta%257D%2520%255Cmid%2520%255Cboldsymbol%257B%255Cmu%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%28B%255Cboldsymbol%257B%255Cmu%257D%252C%255C%252C%2520D%29%2520%255Cqquad%2520%255Cboldsymbol%257B%255Cmu%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%28%255Cmathbf%257B0%257D%252C%255C%252C%2520E%29" alt="equation" width="636" height="26"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The key result: the posterior mean of &lt;code&gt;$\boldsymbol{\theta}$&lt;/code&gt; is a &lt;strong&gt;matrix-weighted average&lt;/strong&gt; of the group-specific &lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;MLE&lt;/a&gt; and the prior mean. Groups with more data (higher precision in &lt;code&gt;$C^{-1}$&lt;/code&gt;) weight their own MLE more heavily; groups with less data lean more on the prior. This is the formal statement of shrinkage.&lt;/p&gt;

&lt;h3&gt;
  
  
  Efron and Morris (1977): The James-Stein Connection
&lt;/h3&gt;

&lt;p&gt;The frequentist justification for shrinkage came from an unexpected direction. In their 1977 &lt;em&gt;Scientific American&lt;/em&gt; article, Brad Efron and Carl Morris popularised a result due to Willard James and Charles Stein: the James-Stein estimator (which shrinks group means toward the grand mean) &lt;strong&gt;dominates&lt;/strong&gt; the usual sample means in terms of total squared error, for three or more groups simultaneously. This was a shocking result: even if the groups have nothing in common, shrinking toward their average reduces total estimation error.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"The James-Stein estimator achieves a smaller total mean squared error than the individual sample means, for any configuration of the true means, provided there are three or more groups."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;The hierarchical Bayesian model produces estimates that are closely related to the James-Stein estimator. The Bayesian framework provides a natural explanation: when data is scarce, it's rational to hedge toward the population average rather than fully committing to a noisy local estimate.&lt;/p&gt;
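
&lt;p&gt;To make the estimator concrete, here is the classic positive-part James-Stein formula, which shrinks &lt;code&gt;k&lt;/code&gt; observed means toward zero; Efron and Morris's version recentres the same formula at the grand mean. A minimal sketch, assuming each mean has known unit variance:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

def james_stein(y, sigma2=1.0):
    """Positive-part James-Stein: shrink k &gt;= 3 observed means toward zero."""
    k = len(y)
    shrink = max(0.0, 1 - (k - 2) * sigma2 / np.sum(y**2))
    return shrink * y

y = np.array([2.0, -1.5, 0.5, 3.0])
print(james_stein(y))  # roughly [1.74, -1.31, 0.44, 2.61]
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;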

&lt;h3&gt;
  
  
  Gelman and Hill (2006)
&lt;/h3&gt;

&lt;p&gt;The practical handbook for hierarchical models is Andrew Gelman and Jennifer Hill's &lt;em&gt;Data Analysis Using Regression and Multilevel/Hierarchical Models&lt;/em&gt;. Chapter 12 presents the exact three-model comparison we built above (complete pooling, no pooling, partial pooling) using radon measurements across US counties. Their formulation uses the non-centred parameterisation:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha_j%2520%253D%2520%255Cmu_%255Calpha%2520%252B%2520%255Csigma_%255Calpha%2520%255Ccdot%2520%255Ceta_j%252C%2520%255Cquad%2520%255Ceta_j%2520%255Csim%2520%255Cmathcal%257BN%257D%280%252C%25201%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha_j%2520%253D%2520%255Cmu_%255Calpha%2520%252B%2520%255Csigma_%255Calpha%2520%255Ccdot%2520%255Ceta_j%252C%2520%255Cquad%2520%255Ceta_j%2520%255Csim%2520%255Cmathcal%257BN%257D%280%252C%25201%29" alt="equation" width="346" height="28"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This reparameterisation often improves MCMC sampling efficiency because the sampler explores a standard Normal geometry rather than a funnel-shaped one. Writing it out takes only a couple of extra lines, and it is the first thing to reach for when your model produces divergences.&lt;/p&gt;
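
&lt;p&gt;A minimal sketch of non-centred group intercepts in PyMC (the hyperprior choices here are illustrative, not the exact priors from the earlier model):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import pymc as pm

n_groups = 3  # e.g. the three policy types

with pm.Model() as ncp_model:
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=1)
    sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=1)
    # Standard-Normal offsets: the sampler sees a fixed geometry,
    # not the funnel created by sampling alpha directly
    eta = pm.Normal('eta', mu=0, sigma=1, shape=n_groups)
    alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * eta)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;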

&lt;p&gt;Gelman et al.'s &lt;em&gt;Bayesian Data Analysis&lt;/em&gt; (3rd edition, 2013) provides the full mathematical treatment in Chapter 5, including the relationship between hierarchical Bayes, empirical Bayes, and the James-Stein estimator.&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;The original formalism:&lt;/strong&gt; Lindley, D. V. &amp;amp; Smith, A. F. M. (1972). "Bayes estimates for the linear model." &lt;em&gt;Journal of the Royal Statistical Society: Series B&lt;/em&gt;, 34(1), 1-41.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The James-Stein connection:&lt;/strong&gt; Efron, B. &amp;amp; Morris, C. (1977). "Stein's paradox in statistics." &lt;em&gt;Scientific American&lt;/em&gt;, 236(5), 119-127.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The practical handbook:&lt;/strong&gt; Gelman, A. &amp;amp; Hill, J. (2006). &lt;em&gt;Data Analysis Using Regression and Multilevel/Hierarchical Models&lt;/em&gt;. Cambridge University Press. Chapters 11-13.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The full Bayesian treatment:&lt;/strong&gt; Gelman, A. et al. (2013). &lt;em&gt;Bayesian Data Analysis&lt;/em&gt;, 3rd ed. CRC Press. Chapter 5.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;PyMC documentation:&lt;/strong&gt; &lt;a href="https://www.pymc.io/projects/examples/en/latest/generalized_linear_models/GLM-hierarchical.html" rel="noopener noreferrer"&gt;PyMC Hierarchical Models tutorial&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Next post in this series:&lt;/strong&gt; Bayesian Survival Analysis, where we extend PyMC to handle censored data using &lt;code&gt;pm.Potential&lt;/code&gt;.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/bayes-theorem-calculator" rel="noopener noreferrer"&gt;Bayes' Theorem Calculator&lt;/a&gt; — Explore Bayesian updating interactively before diving into hierarchical models&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/ab-test-calculator" rel="noopener noreferrer"&gt;A/B Test Calculator&lt;/a&gt; — See Bayesian hypothesis testing in action, a common application of hierarchical models&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt;: The conceptual foundation for priors, posteriors, and why Bayesian estimates beat point estimates.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Island Hopping: Understanding Metropolis-Hastings&lt;/a&gt;: How the sampler that powers PyMC actually explores the posterior distribution.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/linear-regression-five-ways" rel="noopener noreferrer"&gt;Linear Regression: Five Ways&lt;/a&gt;: The non-hierarchical regression baseline that this post extends with group structure.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is hierarchical Bayesian regression?
&lt;/h3&gt;

&lt;p&gt;Hierarchical (or multilevel) regression models data that is naturally grouped (students within schools, patients within hospitals) by allowing parameters to vary across groups while sharing a common prior distribution. This "partial pooling" approach borrows strength across groups, producing better estimates for small groups than fitting each group independently.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the difference between complete pooling, no pooling, and partial pooling?
&lt;/h3&gt;

&lt;p&gt;Complete pooling ignores group differences entirely (one model for all). No pooling fits a separate model per group (no information sharing). Partial pooling (hierarchical) sits in between: each group gets its own parameters, but they are pulled towards a shared distribution. This is especially valuable when some groups have very few observations.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why use PyMC for hierarchical models?
&lt;/h3&gt;

&lt;p&gt;PyMC uses MCMC sampling to handle the complex posterior distributions that hierarchical models produce. It naturally propagates uncertainty through all levels of the hierarchy. Frequentist alternatives (like lme4 in R) can fit similar models but do not provide the same rich uncertainty quantification or flexibility for custom model structures.&lt;/p&gt;

&lt;h3&gt;
  
  
  How do I diagnose convergence in PyMC?
&lt;/h3&gt;

&lt;p&gt;Check the trace plots for good mixing (no trends, no stuck chains), verify that R-hat values are close to 1.0 (below 1.01), and ensure effective sample sizes are sufficiently large (at least 400 per chain). Divergent transitions indicate the sampler is struggling with the posterior geometry and may require reparameterisation.&lt;/p&gt;

&lt;h3&gt;
  
  
  When should I use a hierarchical model instead of a standard regression?
&lt;/h3&gt;

&lt;p&gt;Use hierarchical models whenever your data has a natural grouping structure and you want to make inferences about individual groups. They are especially valuable when group sizes are unequal: small groups benefit from borrowing strength, and large groups are barely affected by the pooling. If all groups have abundant data, the results will be similar to fitting separate models.&lt;/p&gt;

</description>
      <category>bayesian</category>
      <category>probabilistic</category>
      <category>inference</category>
      <category>pymc</category>
    </item>
    <item>
      <title>Solving CartPole Without Gradients: Simulated Annealing</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Thu, 23 Apr 2026 07:51:02 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/solving-cartpole-without-gradients-simulated-annealing-3e47</link>
      <guid>https://dev.to/berkan_sesen/solving-cartpole-without-gradients-simulated-annealing-3e47</guid>
      <description>&lt;p&gt;In the &lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;previous post&lt;/a&gt;, we solved CartPole using the Cross-Entropy Method: sample 200 candidate policies, keep the best 40, refit a Gaussian, repeat. It worked beautifully, reaching a perfect score of 500 in 50 iterations. But 200 candidates per iteration means 10,000 total episode evaluations. That got me wondering: do we really need a population of 200 to find four good numbers?&lt;/p&gt;

&lt;p&gt;The original code that inspired this post took a radically simpler approach. Instead of maintaining a population, it kept a single set of parameters and perturbed them once per iteration. If the perturbation improved the score, it was accepted and the perturbation range was shrunk. That's it. No population, no distribution fitting, no gradients. The comment in the source file read: "its like simulated annealing." By the end of this post, you'll implement this algorithm from scratch, solve CartPole-v1 with a perfect 500 score, and understand how it connects to the rich theory of simulated annealing.&lt;/p&gt;

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;Click the badge to open the interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/simulated_annealing_cartpole.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsrnztpxlba6spbfbo3t5.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsrnztpxlba6spbfbo3t5.gif" alt="Simulated annealing convergence animation: best score climbs from ~10 to 500 by iteration 41, then holds steady" width="800" height="400"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Here's the complete implementation. Like CEM, we use a linear policy with 4 parameters (one per observation dimension). But instead of sampling a population, we perturb a single solution:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;gymnasium&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Run multiple episodes with a linear policy and return the average reward.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;total_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;episode_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
        &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
            &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;episode_reward&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;
            &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;
        &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;close&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;total_reward&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;episode_reward&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;total_reward&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;simulated_annealing&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;80&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_eval_episodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                        &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decay&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Hill climbing with annealing step size for policy search.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;best_score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;best_theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_eval_episodes&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_iter&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# Perturb current best (uniform noise scaled by alpha)
&lt;/span&gt;        &lt;span class="n"&gt;perturbation&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rand&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;
        &lt;span class="n"&gt;candidate&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;perturbation&lt;/span&gt;

        &lt;span class="c1"&gt;# Evaluate candidate over multiple episodes
&lt;/span&gt;        &lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;candidate&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_eval_episodes&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Accept only if better, then shrink step size
&lt;/span&gt;        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;best_score&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;candidate&lt;/span&gt;
            &lt;span class="n"&gt;best_score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;score&lt;/span&gt;
            &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;*=&lt;/span&gt; &lt;span class="n"&gt;decay&lt;/span&gt;

        &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Iter &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="n"&gt;d&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | Score: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;score&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | Best: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_score&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | Alpha: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;best_theta&lt;/span&gt;

&lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;simulated_annealing&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CartPole-v1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Iter   1 | Score:   9.6 | Best:   9.6 | Alpha: 1.0000
# Iter   9 | Score: 128.7 | Best: 128.7 | Alpha: 0.6561
# Iter  14 | Score: 314.2 | Best: 314.2 | Alpha: 0.5314
# Iter  24 | Score: 465.7 | Best: 465.7 | Alpha: 0.4783
# Iter  41 | Score: 500.0 | Best: 500.0 | Alpha: 0.3874
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Perfect score in 41 iterations. Let's verify with 100 evaluation episodes:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CartPole-v1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;best_theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Mean: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; +/- &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Mean: 496 +/- 12
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Four parameters, zero gradients, 800 total episode evaluations. Compare that to &lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;CEM&lt;/a&gt;'s 10,000 episodes or &lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;REINFORCE&lt;/a&gt;'s 5,000.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;The algorithm maintains a single candidate solution and improves it through a cycle of perturb, evaluate, and accept. Here's the full loop:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fe4tvxovb8mfum7rz4und.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fe4tvxovb8mfum7rz4und.webp" alt="SA algorithm flow: start with zeros, perturb with noise scaled by alpha, evaluate over 10 episodes, accept if better (shrink alpha) or reject (keep current)" width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Let's walk through each piece.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Linear Policy
&lt;/h3&gt;

&lt;p&gt;Just like in the CEM post, CartPole has a 4-dimensional observation vector (cart position, cart velocity, pole angle, pole angular velocity). Our policy is a simple dot product:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is a linear classifier: push right if the weighted sum of observations is positive, push left otherwise. The entire "intelligence" of the agent lives in four numbers.&lt;/p&gt;

&lt;h3&gt;
  
  
  Multi-Episode Evaluation
&lt;/h3&gt;

&lt;p&gt;The original code's key insight (noted in a comment: "key thing was to figure out that you need to do 10 tests per point") is to evaluate each candidate over 10 episodes and average the scores. CartPole has stochastic initial conditions, so a single episode can be misleading. A policy might score 500 on one lucky initialisation and 50 on the next. Averaging over 10 episodes gives a stable estimate of true quality.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;candidate&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_eval_episodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
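
&lt;p&gt;You can see that noise directly by scoring the same policy repeatedly with 1 episode versus 10. A quick sketch reusing &lt;code&gt;evaluate_policy&lt;/code&gt; from above (the random theta is arbitrary):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;rng = np.random.default_rng(0)
theta = rng.random(4) - 0.5  # an arbitrary, untrained policy

one_ep = [evaluate_policy('CartPole-v1', theta, n_episodes=1) for _ in range(20)]
ten_ep = [evaluate_policy('CartPole-v1', theta, n_episodes=10) for _ in range(20)]
print(f"single-episode scores: std = {np.std(one_ep):.1f}")
print(f"10-episode averages:   std = {np.std(ten_ep):.1f}")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;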



&lt;h3&gt;
  
  
  The Perturbation Step
&lt;/h3&gt;

&lt;p&gt;Each iteration, we perturb the current best parameters with uniform noise scaled by &lt;code&gt;alpha&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;perturbation&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rand&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;
&lt;span class="n"&gt;candidate&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;perturbation&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;When &lt;code&gt;alpha=1.0&lt;/code&gt;, each parameter can change by up to &lt;code&gt;$\pm 0.5$&lt;/code&gt;. As alpha shrinks, the perturbations get smaller, focusing the search around the current best.&lt;/p&gt;

&lt;h3&gt;
  
  
  Accept and Anneal
&lt;/h3&gt;

&lt;p&gt;Here's the crucial part. We only accept improvements, and we only shrink the step size when we find one:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;best_score&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;candidate&lt;/span&gt;
    &lt;span class="n"&gt;best_score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;score&lt;/span&gt;
    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;*=&lt;/span&gt; &lt;span class="n"&gt;decay&lt;/span&gt;  &lt;span class="c1"&gt;# Shrink step size by 10%
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is an adaptive cooling schedule. If the algorithm keeps finding improvements, alpha decays quickly (&lt;code&gt;$0.9^9 \approx 0.39$&lt;/code&gt; after 9 improvements). If it gets stuck, alpha stays large, maintaining exploration. The algorithm found 9 improvements out of 80 iterations, ending with &lt;code&gt;$\alpha = 0.387$&lt;/code&gt;.&lt;/p&gt;
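
&lt;p&gt;Strictly speaking, this accept-only-improvements rule is hill climbing (as the docstring says); textbook simulated annealing also accepts some &lt;em&gt;worsening&lt;/em&gt; moves, with a probability that shrinks as the temperature drops. A minimal sketch of that classic rule, which is not used in the code above:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;def metropolis_accept(delta, temperature, rng):
    """Classic SA rule: always take improvements, sometimes take regressions."""
    if delta &gt;= 0:
        return True
    return rng.random() &lt; np.exp(delta / temperature)

rng = np.random.default_rng(0)
print(metropolis_accept(-5.0, temperature=10.0, rng=rng))  # often True while hot
print(metropolis_accept(-5.0, temperature=0.1, rng=rng))   # almost never True once cold
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;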

&lt;h3&gt;
  
  
  The Training Curve
&lt;/h3&gt;

&lt;p&gt;The staircase pattern tells the story. Each vertical jump is an accepted improvement; each flat region is the algorithm searching without finding anything better:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;scatter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:],&lt;/span&gt; &lt;span class="n"&gt;candidate_scores&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:],&lt;/span&gt;
            &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;#2ecc71&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;#e74c3c&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;accepted&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]],&lt;/span&gt;
            &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Candidates&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;best_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;b-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best score&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axhline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;500&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;k&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;:&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Max possible (500)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;ax2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;twinx&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alphas&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;k--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Step size (α)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Step size (α)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gray&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fl5cgzdvgyy2y79rw9m2y.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fl5cgzdvgyy2y79rw9m2y.webp" alt="SA training curve showing staircase improvements with candidate scores as coloured dots and step size decay on secondary axis" width="800" height="394"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Green dots are accepted candidates (improvements); red dots are rejected ones. The dashed grey line shows the step size &lt;code&gt;$\alpha$&lt;/code&gt; shrinking on the secondary axis. Notice how the red dots cluster higher as the search progresses, because even rejected perturbations from a good solution tend to produce decent policies.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Hill Climbing vs True Simulated Annealing
&lt;/h3&gt;

&lt;p&gt;Let's be precise about what our algorithm is. The original code's comment called it "like simulated annealing," and the "like" is doing real work: the resemblance is genuine, but so is the distinction.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Our algorithm (hill climbing with annealing step size):&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Accepts only improvements&lt;/li&gt;
&lt;li&gt;Shrinks the step size when an improvement is found&lt;/li&gt;
&lt;li&gt;Never accepts a worse solution&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;True simulated annealing:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Accepts improvements always&lt;/li&gt;
&lt;li&gt;Accepts worse solutions with probability &lt;code&gt;$e^{-\Delta E / T}$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;Shrinks the temperature &lt;code&gt;$T$&lt;/code&gt; on a fixed schedule&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The difference is in how they handle worse solutions. True SA occasionally accepts a downhill move, which allows it to escape local optima. Our algorithm never does, which makes it a strict hill climber. The "annealing" part is only in the step size, not in the acceptance criterion.&lt;/p&gt;

&lt;p&gt;For CartPole with a 4-parameter linear policy, this distinction doesn't matter: the reward landscape is smooth enough that hill climbing works. For harder problems with many local optima, true SA's ability to escape traps becomes essential.&lt;/p&gt;

&lt;p&gt;If you've read the &lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Metropolis-Hastings post&lt;/a&gt;, the acceptance criterion should look familiar. The Metropolis acceptance probability &lt;code&gt;$\min(1, e^{-\Delta E / T})$&lt;/code&gt; is exactly what true SA uses. In MCMC, we want to sample from a distribution; in SA, we want to find its peak. Same mechanism, different goal.&lt;/p&gt;
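
&lt;p&gt;For reference, here is a sketch of what the accept step would look like with the Metropolis criterion, which is the change needed to turn our hill climber into true SA (the temperature handling is illustrative, not tuned):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

def metropolis_accept(candidate_score, current_score, T):
    """Always accept improvements; accept worse moves with prob e^(delta/T)."""
    delta = candidate_score - current_score  # we maximise, so delta &lt; 0 is worse
    if delta &gt;= 0:
        return True
    return np.random.rand() &lt; np.exp(delta / T)

# Inside the loop, T would decay on a fixed schedule (e.g. T *= 0.95),
# and the search would move from the *current* point, not the best-so-far.
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;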

&lt;h3&gt;
  
  
  The Cooling Schedule
&lt;/h3&gt;

&lt;p&gt;Our algorithm uses a multiplicative decay: &lt;code&gt;$\alpha_{t+1} = 0.9 \cdot \alpha_t$&lt;/code&gt; on each improvement. This creates a geometric sequence:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha_k%2520%253D%2520%255Calpha_0%2520%255Ccdot%2520%255Cgamma%255Ek%2520%253D%25201.0%2520%255Ccdot%25200.9%255Ek" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha_k%2520%253D%2520%255Calpha_0%2520%255Ccdot%2520%255Cgamma%255Ek%2520%253D%25201.0%2520%255Ccdot%25200.9%255Ek" alt="equation" width="249" height="28"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$k$&lt;/code&gt; is the number of improvements found. After 9 improvements, &lt;code&gt;$\alpha = 0.9^9 \approx 0.387$&lt;/code&gt;.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alphas&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;b-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;       &lt;span class="c1"&gt;# Alpha vs iterations
&lt;/span&gt;&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;          &lt;span class="c1"&gt;# Geometric decay curves
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fs92kvrqg9qt5eryy4660.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fs92kvrqg9qt5eryy4660.webp" alt="Cooling schedule: left panel shows step size over iterations with green bars marking improvements; right panel compares geometric decay rates" width="800" height="261"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The left panel shows alpha over iterations, with green bands marking accepted improvements. The right panel compares different decay rates. A faster decay (&lt;code&gt;$\gamma = 0.8$&lt;/code&gt;) converges to fine-tuning quickly but risks getting stuck. A slower decay (&lt;code&gt;$\gamma = 0.95$&lt;/code&gt;) explores longer but takes more iterations to refine. The original code's choice of 0.9 strikes a reasonable balance.&lt;/p&gt;

&lt;p&gt;What makes our schedule adaptive is that it only decays on improvement. Traditional SA uses fixed schedules (logarithmic, linear, or exponential decay in wall-clock time). Our variant keeps &lt;code&gt;$\alpha$&lt;/code&gt; large during plateaus, naturally spending more time exploring when stuck and more time refining when making progress.&lt;/p&gt;
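
&lt;p&gt;Here is a tiny illustration of that behaviour (the improvement pattern below is invented for the example):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

improved = np.zeros(80, dtype=bool)
improved[[0, 1, 2, 60]] = True     # fast start, then a long plateau

alpha, trajectory = 1.0, []
for t in range(80):
    if improved[t]:
        alpha *= 0.9               # adaptive: cool only on improvement
    trajectory.append(alpha)

print(round(trajectory[3], 3), round(trajectory[59], 3))
# 0.729 0.729 -- alpha holds steady through the whole plateau
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;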

&lt;h3&gt;
  
  
  SA vs CEM: One Climber vs a Search Party
&lt;/h3&gt;

&lt;p&gt;The &lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;Cross-Entropy Method&lt;/a&gt; we built last time and simulated annealing sit at opposite ends of the derivative-free spectrum:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Aspect&lt;/th&gt;
&lt;th&gt;Simulated Annealing&lt;/th&gt;
&lt;th&gt;CEM&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Search strategy&lt;/td&gt;
&lt;td&gt;Single point, local perturbations&lt;/td&gt;
&lt;td&gt;Population of 200, distribution fitting&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Episodes per iteration&lt;/td&gt;
&lt;td&gt;10&lt;/td&gt;
&lt;td&gt;200 (200 candidates x 1 each)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Total episodes to solve CartPole&lt;/td&gt;
&lt;td&gt;~800&lt;/td&gt;
&lt;td&gt;~10,000 (200 x 50 iterations)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Information used&lt;/td&gt;
&lt;td&gt;"Is this better than the best?" (1 bit)&lt;/td&gt;
&lt;td&gt;Full reward ranking of all candidates&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Robustness&lt;/td&gt;
&lt;td&gt;Seed-dependent; some runs may fail&lt;/td&gt;
&lt;td&gt;Highly robust; population averages out noise&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Parallelisable&lt;/td&gt;
&lt;td&gt;No (sequential by nature)&lt;/td&gt;
&lt;td&gt;Yes (all 200 evaluations are independent)&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;SA is like a single hiker exploring a mountain range, taking one step at a time and only moving to higher ground. CEM is like sending 200 hikers, ranking them by altitude, and teleporting the next batch to the region where the best ones clustered.&lt;/p&gt;

&lt;p&gt;SA wins on sample efficiency (fewer total episodes) but loses on reliability. Run SA with a different random seed and you might need 20 iterations or 200. CEM's population averaging makes it much more consistent.&lt;/p&gt;

&lt;h3&gt;
  
  
  SA vs Random Search
&lt;/h3&gt;

&lt;p&gt;How much does the "annealing" (building on previous improvements) actually help, compared to just sampling random policies each time?&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sa_best_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;b-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Simulated annealing&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_best_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;r--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Random search&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frdd467x63atvwvt9e5kx.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frdd467x63atvwvt9e5kx.webp" alt="SA reaching 500 while random search plateaus at 387 after 80 iterations" width="800" height="394"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Random search samples a fresh random policy each iteration (uniform in &lt;code&gt;$[-1, 1]^4$&lt;/code&gt;) and tracks the best one found. After 80 iterations, its best score is 387 vs SA's 500. Random search got lucky once (iteration 2) and found a decent policy early, but it can never refine it. SA's ability to make small improvements to an already-good solution is what pushes it from "decent" to "perfect."&lt;/p&gt;
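
&lt;p&gt;The baseline itself is a few lines. This sketch reuses the &lt;code&gt;evaluate_policy&lt;/code&gt; helper from the CEM post and, for brevity, scores each candidate on a single episode rather than the 10 used by the SA loop:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

def random_search(env_name, n_params=4, n_iter=80):
    """Baseline: a fresh uniform policy each iteration, keep the best seen."""
    best_theta, best_score = None, -np.inf
    for _ in range(n_iter):
        theta = np.random.uniform(-1, 1, n_params)  # no refinement step
        score = evaluate_policy(env_name, theta)
        if score &gt; best_score:
            best_theta, best_score = theta, score
    return best_theta, best_score
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;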

&lt;h3&gt;
  
  
  Hyperparameters
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Parameter&lt;/th&gt;
&lt;th&gt;Value&lt;/th&gt;
&lt;th&gt;Effect&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;alpha&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;1.0&lt;/td&gt;
&lt;td&gt;Initial step size. Perturbations range in &lt;code&gt;$[-0.5, 0.5]$&lt;/code&gt; per parameter&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;decay&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;0.9&lt;/td&gt;
&lt;td&gt;Step size multiplier on improvement. Lower = faster convergence, less exploration&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;n_iter&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;80&lt;/td&gt;
&lt;td&gt;Total iterations. Our run converged at iteration 41&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;n_eval_episodes&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;10&lt;/td&gt;
&lt;td&gt;Episodes per evaluation. More = less noise, more compute&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;The most sensitive parameter is &lt;code&gt;decay&lt;/code&gt;. At 0.9, alpha halves after about 7 improvements. At 0.8, it halves after 4. Too aggressive and the step size collapses before finding a good solution; too conservative and you waste iterations on large perturbations when you're already close.&lt;/p&gt;
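
&lt;p&gt;The halving counts follow from the closed form &lt;code&gt;$\gamma^k = 1/2 \Rightarrow k = \ln 2 / \ln(1/\gamma)$&lt;/code&gt;; rounding up to whole improvements gives the figures above:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

for decay in (0.8, 0.9, 0.95, 0.99):
    k = np.log(2) / np.log(1 / decay)  # improvements until alpha halves
    print(f"decay={decay}: alpha halves after ~{k:.1f} improvements")
# decay=0.8: ~3.1 | decay=0.9: ~6.6 | decay=0.95: ~13.5 | decay=0.99: ~69.0
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;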

&lt;h3&gt;
  
  
  When NOT to Use This Approach
&lt;/h3&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;High-dimensional parameter spaces.&lt;/strong&gt; A single perturbation in 1000 dimensions is unlikely to improve on the current best by chance. Population methods like &lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;CEM&lt;/a&gt; or &lt;a href="https://sesen.ai/blog/genetic-algorithms-from-scratch" rel="noopener noreferrer"&gt;genetic algorithms&lt;/a&gt; scale better&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Multi-modal reward landscapes.&lt;/strong&gt; Our hill climber can only find the nearest peak. If the global optimum is separated by a valley, you'll never reach it without true SA's downhill acceptance&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;When you need guarantees.&lt;/strong&gt; SA is a heuristic. Even true SA only guarantees convergence to the global optimum with logarithmic cooling, which is impractically slow&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;When wall-clock time matters more than sample efficiency.&lt;/strong&gt; SA is inherently sequential. CEM's 200 evaluations per iteration can run in parallel, making it faster on multi-core hardware despite using 12x more episodes&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;p&gt;Simulated annealing was introduced independently by &lt;strong&gt;Scott Kirkpatrick, Daniel Gelatt, and Mario Vecchi&lt;/strong&gt; at IBM Research in their 1983 Science paper &lt;a href="https://doi.org/10.1126/science.220.4598.671" rel="noopener noreferrer"&gt;"Optimization by Simulated Annealing"&lt;/a&gt;, and by &lt;strong&gt;Vladimír Černý&lt;/strong&gt; in 1985. The name comes from the metallurgical process of annealing: heating a metal and then slowly cooling it to reduce defects in its crystal structure.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Metallurgy Analogy
&lt;/h3&gt;

&lt;p&gt;When you heat metal, atoms vibrate wildly and can escape local energy minima. As the temperature drops, atoms settle into increasingly stable configurations. If you cool slowly enough, the metal reaches its lowest-energy crystal state (the global optimum). Cool too fast and you get a brittle, disordered structure (a local optimum).&lt;/p&gt;

&lt;p&gt;Kirkpatrick and colleagues mapped this physical process to combinatorial optimisation:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Metal atoms&lt;/strong&gt; become candidate solutions&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Energy&lt;/strong&gt; becomes the cost function&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Temperature&lt;/strong&gt; becomes a control parameter that governs randomness&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  The Metropolis Connection
&lt;/h3&gt;

&lt;p&gt;The acceptance criterion in true SA comes directly from the &lt;strong&gt;Metropolis algorithm&lt;/strong&gt; (Metropolis, Rosenbluth, Rosenbluth, Teller, and Teller, 1953), originally designed for simulating atomic systems in statistical mechanics. At temperature &lt;code&gt;$T$&lt;/code&gt;, a new state with energy &lt;code&gt;$E'$&lt;/code&gt; is accepted from a current state with energy &lt;code&gt;$E$&lt;/code&gt; with probability:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28%255Ctext%257Baccept%257D%29%2520%253D%2520%255Cbegin%257Bcases%257D%25201%2520%2526%2520%255Ctext%257Bif%2520%257D%2520E%27%2520%253C%2520E%2520%255C%255C%2520e%255E%257B-%28E%27%2520-%2520E%29%2520%252F%2520T%257D%2520%2526%2520%255Ctext%257Bif%2520%257D%2520E%27%2520%255Cgeq%2520E%2520%255Cend%257Bcases%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28%255Ctext%257Baccept%257D%29%2520%253D%2520%255Cbegin%257Bcases%257D%25201%2520%2526%2520%255Ctext%257Bif%2520%257D%2520E%27%2520%253C%2520E%2520%255C%255C%2520e%255E%257B-%28E%27%2520-%2520E%29%2520%252F%2520T%257D%2520%2526%2520%255Ctext%257Bif%2520%257D%2520E%27%2520%255Cgeq%2520E%2520%255Cend%257Bcases%257D" alt="equation" width="390" height="75"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;At high &lt;code&gt;$T$&lt;/code&gt;, the exponential is close to 1, so almost any move is accepted (random exploration). At low &lt;code&gt;$T$&lt;/code&gt;, only improvements or tiny degradations are accepted (local refinement). As &lt;code&gt;$T \to 0$&lt;/code&gt;, the algorithm becomes pure hill climbing.&lt;/p&gt;
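
&lt;p&gt;Plugging numbers in makes the temperature's role concrete. For a candidate that is worse by &lt;code&gt;$\Delta E = 10$&lt;/code&gt; (an illustrative value):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

delta_E = 10.0                    # how much worse the candidate is
for T in (100.0, 10.0, 1.0, 0.1):
    p = np.exp(-delta_E / T)      # Metropolis acceptance probability
    print(f"T={T}: accept worse move with p = {p:.3g}")
# T=100: p = 0.905 | T=10: p = 0.368 | T=1: p = 4.54e-05 | T=0.1: p = 3.72e-44
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;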

&lt;p&gt;This is the same acceptance probability we explored in the &lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;Metropolis-Hastings post&lt;/a&gt; for MCMC sampling. The only difference: in MCMC, we maintain a high temperature to sample broadly; in SA, we lower it to converge on a peak. Same mechanism, different goals.&lt;/p&gt;

&lt;h3&gt;
  
  
  Our Variant vs Classical SA
&lt;/h3&gt;

&lt;p&gt;Our implementation simplifies classical SA in two ways:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;No downhill acceptance.&lt;/strong&gt; We only accept improvements, making our algorithm a strict hill climber. Classical SA would occasionally accept a worse solution, with probability decreasing as the temperature drops&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Adaptive cooling.&lt;/strong&gt; Classical SA uses a fixed cooling schedule (e.g., &lt;code&gt;$T_k = T_0 / \log(1+k)$&lt;/code&gt; for the theoretical guarantee). Our schedule only cools when an improvement is found, which adapts the exploration rate to the difficulty of the landscape&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Despite these simplifications, our algorithm captures SA's core idea: start with large moves (exploration) and gradually transition to small moves (exploitation). For low-dimensional problems like our 4-parameter CartPole policy, this simplified variant works as well as the full SA.&lt;/p&gt;

&lt;h3&gt;
  
  
  Theoretical Guarantees
&lt;/h3&gt;

&lt;p&gt;Kirkpatrick et al. demonstrated SA empirically; the convergence guarantee came later. Geman and Geman (1984) proved that SA with logarithmic cooling (&lt;code&gt;$T_k = c / \log(1+k)$&lt;/code&gt;) converges to the global optimum in probability. However, this schedule is impractically slow for real problems. In practice, faster geometric schedules (&lt;code&gt;$T_{k+1} = \alpha T_k$&lt;/code&gt;) are used, sacrificing the global optimality guarantee for practical convergence speed.&lt;/p&gt;
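
&lt;p&gt;A quick back-of-the-envelope comparison shows why the logarithmic schedule is impractical (constants chosen purely for illustration):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

T0, c = 1.0, 1.0
for k in (10, 100, 10_000, 1_000_000):
    T_log = c / np.log(1 + k)     # guaranteed but glacial
    T_geo = T0 * 0.95 ** k        # practical geometric cooling
    print(f"k={k}: log schedule T={T_log:.3f}, geometric T={T_geo:.2e}")
# After a million steps the log schedule has only cooled to T = 0.072,
# while the geometric schedule hit (numerically) zero long before.
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;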

&lt;blockquote&gt;
&lt;p&gt;"There is a deep and useful connection between statistical mechanics [...] and multivariate or combinatorial optimization. [...] We have applied this framework to the design of computer hardware, to a specific and practical problem in computer layout."&lt;br&gt;
&lt;em&gt;Kirkpatrick, Gelatt, and Vecchi (1983)&lt;/em&gt;&lt;/p&gt;
&lt;/blockquote&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1126/science.220.4598.671" rel="noopener noreferrer"&gt;Kirkpatrick, Gelatt, and Vecchi (1983)&lt;/a&gt;, "Optimization by Simulated Annealing" - The foundational paper. Read Section II for the algorithm and Section IV for the VLSI application&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1063/1.1699114" rel="noopener noreferrer"&gt;Metropolis et al. (1953)&lt;/a&gt;, "Equation of State Calculations by Fast Computing Machines" - The acceptance criterion used by SA, originally for molecular simulation&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1007/BF00940812" rel="noopener noreferrer"&gt;Cerny (1985)&lt;/a&gt;, "Thermodynamical Approach to the Traveling Salesman Problem" - Independent invention of SA for TSP&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Sutton and Barto (2018)&lt;/strong&gt;, Ch. 1 - Context for derivative-free methods in the RL landscape&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1007/978-1-4757-4321-0" rel="noopener noreferrer"&gt;Rubinstein and Kroese (2004)&lt;/a&gt;, &lt;em&gt;The Cross-Entropy Method&lt;/em&gt; - For comparison with the population-based approach&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/simulated_annealing_cartpole.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Decay rate sweep&lt;/strong&gt;: Try &lt;code&gt;decay&lt;/code&gt; values of 0.8, 0.9, 0.95, and 0.99. How does the cooling speed affect convergence? Is there a sweet spot?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;True simulated annealing&lt;/strong&gt;: Modify the algorithm to accept worse solutions with probability &lt;code&gt;$e^{-\Delta / T}$&lt;/code&gt; where &lt;code&gt;$\Delta$&lt;/code&gt; is the score difference and &lt;code&gt;$T$&lt;/code&gt; decays on a fixed schedule. Does it help on CartPole? When would it matter?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Seed sensitivity&lt;/strong&gt;: Run the algorithm 20 times with different random seeds. What fraction of runs reach 500? How does this compare to CEM's reliability?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Harder environments&lt;/strong&gt;: Try SA on &lt;code&gt;Acrobot-v1&lt;/code&gt; or &lt;code&gt;MountainCar-v0&lt;/code&gt;. Does the 4-parameter linear policy have enough capacity, or do these environments need a richer policy class?&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/q-learning-visualizer" rel="noopener noreferrer"&gt;Q-Learning Visualiser&lt;/a&gt; — Compare SA's derivative-free approach with value-based RL on grid worlds&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;The Cross-Entropy Method: Solving RL Without Gradients&lt;/a&gt; - The population-based companion to SA. Both are derivative-free, but CEM trades sample efficiency for robustness by maintaining 200 candidates per iteration.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Island Hopping: Understanding Metropolis-Hastings&lt;/a&gt; - The acceptance criterion that powers true SA comes directly from the Metropolis algorithm. In MCMC we sample from a distribution; in SA we find its peak.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/genetic-algorithms-from-scratch" rel="noopener noreferrer"&gt;Genetic Algorithms: From Line Fitting to the Travelling Salesman&lt;/a&gt; - Another derivative-free optimisation family. GAs use crossover and mutation on a population; SA uses perturbation on a single solution.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  How is simulated annealing different from random search?
&lt;/h3&gt;

&lt;p&gt;Random search samples a completely new policy each iteration and tracks the best one found, but it can never refine a promising solution. Simulated annealing builds on previous improvements by perturbing the current best parameters with decreasing noise. This ability to make small refinements to an already-good solution is what pushes SA from "decent" to "perfect" on CartPole.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why does the algorithm evaluate each candidate over 10 episodes instead of 1?
&lt;/h3&gt;

&lt;p&gt;CartPole has stochastic initial conditions, so a single episode can be misleading. A policy might score 500 on one lucky initialisation and 50 on the next. Averaging over 10 episodes gives a stable estimate of true quality, preventing the algorithm from accepting a lucky fluke or rejecting a good policy due to bad luck.&lt;/p&gt;

&lt;h3&gt;
  
  
  Is this true simulated annealing?
&lt;/h3&gt;

&lt;p&gt;Not quite. True simulated annealing occasionally accepts worse solutions with a probability that decreases over time, allowing it to escape local optima. Our implementation is a strict hill climber that only accepts improvements. The "annealing" part refers only to the shrinking step size. For CartPole's smooth 4-parameter landscape, this distinction does not matter, but for problems with many local optima, true SA's downhill acceptance becomes essential.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why does the step size only shrink when an improvement is found?
&lt;/h3&gt;

&lt;p&gt;This creates an adaptive cooling schedule. If the algorithm keeps finding improvements, the step size decays quickly, focusing the search around the current best. If it gets stuck in a plateau, the step size stays large, maintaining broad exploration. This naturally spends more time exploring when stuck and more time refining when making progress.&lt;/p&gt;

&lt;h3&gt;
  
  
  When would simulated annealing fail compared to population-based methods?
&lt;/h3&gt;

&lt;p&gt;SA struggles in high-dimensional parameter spaces where a single random perturbation is unlikely to improve all parameters at once. It also fails on multi-modal reward landscapes because, as a strict hill climber, it can only find the nearest peak. Population-based methods like the Cross-Entropy Method or genetic algorithms handle both cases better by maintaining diversity across many candidates simultaneously.&lt;/p&gt;

</description>
      <category>reinforcementlearning</category>
      <category>optimisation</category>
    </item>
    <item>
      <title>The Cross-Entropy Method: Solving RL Without Gradients</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Tue, 21 Apr 2026 08:27:46 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/the-cross-entropy-method-solving-rl-without-gradients-1lol</link>
      <guid>https://dev.to/berkan_sesen/the-cross-entropy-method-solving-rl-without-gradients-1lol</guid>
      <description>&lt;p&gt;Reinforcement learning has accumulated layers of complexity over the years: value functions, policy gradients, replay buffers, target networks. The Cross-Entropy Method predates all of it. Rubinstein introduced it in 1997 for rare-event simulation, and it turned out to solve simple control tasks with almost no machinery. The entire implementation fits in 50 lines. No gradients, no training loops. Just: sample some parameters, test them, keep the best ones, repeat.&lt;/p&gt;

&lt;p&gt;The Cross-Entropy Method (CEM) is the algorithm you reach for when you want results without complexity. It treats the policy's parameters as a black box, maintains a probability distribution over them, and iteratively narrows that distribution toward high-performing regions. No gradients required. By the end of this post, you'll implement CEM from scratch, solve CartPole-v1 with a perfect score, and understand why this "naive" approach works so well on problems with manageable parameter spaces.&lt;/p&gt;

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;Click the badge to open the interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/cross_entropy_method.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fh9lzj24j8fzzyzxjhagc.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fh9lzj24j8fzzyzxjhagc.gif" alt="CEM convergence animation showing the reward distribution shifting from low to high over 50 iterations" width="800" height="500"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Here's the complete implementation. We use a linear policy with just 4 parameters (one per observation dimension), and CEM finds the perfect weights:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;gymnasium&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Run one episode with a linear policy: action = 1 if theta @ obs &amp;gt; 0 else 0.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;total_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
    &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;total_reward&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;
        &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;
    &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;close&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;total_reward&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;cem&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;elite_frac&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;initial_std&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;extra_std&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;std_decay_time&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;25&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Cross-Entropy Method for policy search.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;n_elite&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;round&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;elite_frac&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;th_mean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;th_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;initial_std&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;iteration&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_iter&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# Decaying extra noise (Szita &amp;amp; Lörincz 2006)
&lt;/span&gt;        &lt;span class="n"&gt;noise_multiplier&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;iteration&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;std_decay_time&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;sample_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;th_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;square&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;extra_std&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;noise_multiplier&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Sample and evaluate
&lt;/span&gt;        &lt;span class="n"&gt;thetas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;th_mean&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;sample_std&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;th&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;th&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;thetas&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

        &lt;span class="c1"&gt;# Select elite and refit distribution
&lt;/span&gt;        &lt;span class="n"&gt;elite_inds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argsort&lt;/span&gt;&lt;span class="p"&gt;()[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;n_elite&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;
        &lt;span class="n"&gt;elite_thetas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;thetas&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;elite_inds&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;th_mean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;elite_thetas&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;th_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;elite_thetas&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;var&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Iter &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;iteration&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="n"&gt;d&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | Mean: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="mf"&gt;6.1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | Max: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;th_mean&lt;/span&gt;

&lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;cem&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CartPole-v1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Iter   1 | Mean:   66.8 | Max: 500
# Iter  10 | Mean:  384.0 | Max: 500
# Iter  30 | Mean:  495.2 | Max: 500
# Iter  50 | Mean:  499.1 | Max: 500
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The population mean reward climbs from 67 to 499 in 50 iterations. Every single sample in the final batch scores near-perfect. Let's verify with 100 evaluation episodes:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CartPole-v1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;best_theta&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Mean: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; ± &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Mean: 500 ± 0
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Perfect score. Four parameters, zero gradients, 50 iterations.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;CEM works by maintaining a Gaussian distribution over policy parameters and repeatedly narrowing it toward the best-performing region. Each iteration has three steps:&lt;/p&gt;

&lt;h3&gt;
  
  
  Step 1: Sample
&lt;/h3&gt;

&lt;p&gt;We draw &lt;code&gt;batch_size=200&lt;/code&gt; parameter vectors from a Gaussian:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;thetas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;th_mean&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;sample_std&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Each &lt;code&gt;theta&lt;/code&gt; is a candidate policy. In iteration 1, the mean is zeros and the standard deviation is roughly 1 (1.0 from &lt;code&gt;initial_std&lt;/code&gt;, plus the decaying extra noise), so we're sampling essentially random policies.&lt;/p&gt;

&lt;h3&gt;
  
  
  Step 2: Evaluate and Select
&lt;/h3&gt;

&lt;p&gt;We run each candidate policy on CartPole and rank them by total reward. Then we keep only the top 20% (the "elite" set):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;th&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;th&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;thetas&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;elite_inds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argsort&lt;/span&gt;&lt;span class="p"&gt;()[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;n_elite&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;  &lt;span class="c1"&gt;# Top 40 out of 200
&lt;/span&gt;&lt;span class="n"&gt;elite_thetas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;thetas&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;elite_inds&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Step 3: Refit the Distribution
&lt;/h3&gt;

&lt;p&gt;We refit the Gaussian to match the elite samples:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;th_mean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;elite_thetas&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;th_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;elite_thetas&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;var&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The new mean moves toward parameters that performed well. The new variance shrinks because the elite samples cluster together. Next iteration, we sample from this tighter distribution, generating better candidates on average.&lt;/p&gt;
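
&lt;p&gt;A toy refit makes this concrete (the elite samples below are invented for illustration; note that, following the notebook's code, &lt;code&gt;th_std&lt;/code&gt; actually stores the variance):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

# Pretend 4 elite parameter vectors survived selection (2 params each)
elite_thetas = np.array([[0.9, -0.4],
                         [1.1, -0.6],
                         [1.0, -0.5],
                         [0.8, -0.5]])

th_mean = elite_thetas.mean(axis=0)  # [0.95, -0.5]: pulled toward the winners
th_std = elite_thetas.var(axis=0)    # [0.0125, 0.005]: tighter next iteration
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;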

&lt;h3&gt;
  
  
  Watching It Converge
&lt;/h3&gt;

&lt;p&gt;The training curve shows how the population improves:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mean_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;b-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Population mean&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;elite_mean_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;r-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Elite mean&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;g--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best in batch&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axhline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;500&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;k&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;:&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Max possible (500)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Iteration&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Total Reward&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjsrtq0yzp11jtcxexvvj.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjsrtq0yzp11jtcxexvvj.webp" alt="CEM training curve on CartPole-v1 showing population mean climbing from 67 to 500 over 50 iterations" width="800" height="450"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The elite mean hits 500 almost immediately (iteration 2). But the population mean takes longer to catch up because the distribution is still wide. By iteration 30, even randomly sampled policies from the learned distribution score a near-perfect reward.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Distribution Narrows Over Time
&lt;/h3&gt;

&lt;p&gt;To see this visually, here's how the reward distribution across the 200 samples evolves:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;iteration_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;title&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axes&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;selected_iterations&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;titles&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;hist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iteration_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;steelblue&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;edgecolor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;white&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;iteration_rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;red&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F7n9zgznzl5aqo9ix7wa5.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F7n9zgznzl5aqo9ix7wa5.webp" alt="Reward distributions at iterations 1, 10, and 50 showing the population concentrating at 500" width="800" height="261"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;In iteration 1, most policies fail quickly (reward &amp;lt; 100) with a few lucky ones reaching 500. By iteration 10, the distribution is bimodal: many policies near 500 but some still struggling. By iteration 50, the entire population clusters at 500. The distribution has collapsed onto the solution.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Noisy Cross-Entropy Method
&lt;/h3&gt;

&lt;p&gt;The original CEM (Rubinstein 1999) has a failure mode: the variance can collapse to zero too quickly, trapping the search in a local optimum. Szita and Lörincz (2006) fixed this with the "noisy" variant that adds decaying extra variance:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Csigma_%257Bt%252B1%257D%255E2%2520%253D%2520%255Csigma_%257Bt%252C%255Ctext%257Belite%257D%257D%255E2%2520%252B%2520Z_t%255E2%2520%255Ccdot%2520%255Csigma_%257B%255Ctext%257Bextra%257D%257D%255E2" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Csigma_%257Bt%252B1%257D%255E2%2520%253D%2520%255Csigma_%257Bt%252C%255Ctext%257Belite%257D%257D%255E2%2520%252B%2520Z_t%255E2%2520%255Ccdot%2520%255Csigma_%257B%255Ctext%257Bextra%257D%257D%255E2" alt="equation" width="265" height="31"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$Z_t = \max(1 - t / T_{\text{decay}},\; 0)$&lt;/code&gt; decays linearly to zero. Early iterations get extra exploration; later iterations trust the elite variance.&lt;/p&gt;

&lt;p&gt;This is exactly what our code does:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;noise_multiplier&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;iteration&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;std_decay_time&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;sample_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;th_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;square&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;extra_std&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;noise_multiplier&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The extra noise (&lt;code&gt;extra_std=0.5&lt;/code&gt;) decays linearly over &lt;code&gt;std_decay_time=25&lt;/code&gt; iterations. After iteration 25, the sampling distribution uses only the elite variance.&lt;/p&gt;
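
&lt;p&gt;A quick sanity check of that schedule (a standalone sketch using the same constants):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;std_decay_time = 25

for iteration in (0, 10, 25, 40):
    noise_multiplier = max(1.0 - iteration / float(std_decay_time), 0)
    print(iteration, noise_multiplier)  # 1.0, 0.6, 0.0, 0.0 -- no extra noise after iteration 25
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
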

&lt;h3&gt;
  
  
  Hyperparameters
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Parameter&lt;/th&gt;
&lt;th&gt;Value&lt;/th&gt;
&lt;th&gt;Effect&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;batch_size&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;200&lt;/td&gt;
&lt;td&gt;More samples = better coverage but slower per iteration&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;elite_frac&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;0.2&lt;/td&gt;
&lt;td&gt;Lower = more selective, faster convergence, risk of premature collapse&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;initial_std&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;1.0&lt;/td&gt;
&lt;td&gt;Too low = miss good regions; too high = waste samples on extreme policies&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;extra_std&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;0.5&lt;/td&gt;
&lt;td&gt;Noise injection; 0 = original CEM, &amp;gt;0 = noisy CEM&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;std_decay_time&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;25&lt;/td&gt;
&lt;td&gt;How many iterations before extra noise disappears&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;The most sensitive parameter is &lt;code&gt;elite_frac&lt;/code&gt;. At 0.2 (keep top 40 of 200), we balance exploitation and exploration. Setting it to 0.01 (keep top 2) would converge faster in easy environments but collapse in hard ones.&lt;/p&gt;
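
&lt;p&gt;In code terms, the knob simply sets how many samples survive selection (constants from the table above):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;batch_size, elite_frac = 200, 0.2
n_elite = int(batch_size * elite_frac)
print(n_elite)  # 40 -- only these samples inform the next distribution
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
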

&lt;h3&gt;
  
  
  CEM vs Random Search
&lt;/h3&gt;

&lt;p&gt;Both CEM and random search sample 200 policies per iteration. The difference: random search starts fresh every time, while CEM builds on what worked:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cem_mean_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;b-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CEM (population mean)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_mean_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;r--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Random search (mean)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkw1tl8s1h1urkxx3ka5d.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkw1tl8s1h1urkxx3ka5d.webp" alt="CEM population mean climbing to 500 while random search stays flat at ~60" width="800" height="450"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Random search averages about 60 reward per iteration, forever. CEM reaches 500 because each iteration's distribution is informed by the last. The "select and refit" loop creates a directed search through parameter space.&lt;/p&gt;
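
&lt;p&gt;For reference, the random-search baseline reduces to the sketch below (reusing the &lt;code&gt;evaluate_policy&lt;/code&gt; helper and &lt;code&gt;env_name&lt;/code&gt; from earlier; &lt;code&gt;n_params=4&lt;/code&gt; assumes the linear CartPole policy). The only difference from CEM is that nothing is refit between iterations:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

# Random-search baseline (sketch): draw from a FIXED distribution every
# iteration -- no elite selection, no refit step
batch_size, n_params, initial_std = 200, 4, 1.0

for iteration in range(50):
    thetas = np.random.randn(batch_size, n_params) * initial_std
    rewards = np.array([evaluate_policy(env_name, th) for th in thetas])
    print(iteration, rewards.mean())  # hovers around the same level forever
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
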

&lt;h3&gt;
  
  
  CEM vs Policy Gradients vs DQN
&lt;/h3&gt;

&lt;p&gt;How does CEM compare to the gradient-based methods we've covered?&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Method&lt;/th&gt;
&lt;th&gt;What it optimises&lt;/th&gt;
&lt;th&gt;Needs gradients?&lt;/th&gt;
&lt;th&gt;Scales to large nets?&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;CEM&lt;/td&gt;
&lt;td&gt;Policy parameters directly&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;td&gt;Poorly (&amp;gt;1000 params)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;REINFORCE&lt;/a&gt;&lt;/td&gt;
&lt;td&gt;Policy parameters via log-prob gradient&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;DQN&lt;/a&gt;&lt;/td&gt;
&lt;td&gt;Value function (Q-values)&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;Q-Learning&lt;/a&gt;&lt;/td&gt;
&lt;td&gt;Value function (Q-table)&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;td&gt;No (tabular only)&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;CEM's sweet spot: &lt;strong&gt;problems with fewer than ~1000 parameters&lt;/strong&gt; where you want a simple, parallelisable algorithm. For a 4-parameter linear policy on CartPole, CEM is hard to beat. For a million-parameter Atari network, you need &lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;policy gradients&lt;/a&gt; or DQN.&lt;/p&gt;

&lt;h3&gt;
  
  
  When NOT to Use CEM
&lt;/h3&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;High-dimensional parameter spaces.&lt;/strong&gt; Sampling becomes exponentially less effective as the dimension grows; a 1000-parameter network would need enormous batch sizes&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Environments with sparse rewards.&lt;/strong&gt; If most policies score zero (e.g., Montezuma's Revenge), the elite set is just noise&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;When you need sample efficiency.&lt;/strong&gt; CEM burns through 200 episodes per iteration, versus roughly 5 per batch for REINFORCE. If environment evaluations are expensive, gradient methods win&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Continuous action spaces with complex dynamics.&lt;/strong&gt; CEM with a linear policy can only learn linear decision boundaries. Problems requiring nonlinear policies need either a neural network (large parameter space) or a different algorithm&lt;/li&gt;
&lt;/ol&gt;

&lt;h3&gt;
  
  
  Connection to Genetic Algorithms
&lt;/h3&gt;

&lt;p&gt;If you read the &lt;a href="https://sesen.ai/blog/genetic-algorithms-from-scratch" rel="noopener noreferrer"&gt;genetic algorithms post&lt;/a&gt;, CEM will feel familiar. Both are population-based, derivative-free optimisation methods. The difference is in how they generate the next population:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Genetic algorithms&lt;/strong&gt; use crossover and mutation operators on individual solutions&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;CEM&lt;/strong&gt; fits a probability distribution to the elite set and samples from it&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;CEM is sometimes called an "estimation of distribution algorithm" (EDA). Instead of recombining individual solutions, it models the structure of good solutions as a distribution and samples new candidates from that model. For real-valued parameter optimisation, this Gaussian model is often more effective than genetic crossover.&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;p&gt;The Cross-Entropy Method was introduced by &lt;strong&gt;Reuven Rubinstein&lt;/strong&gt; in his 1999 paper &lt;a href="https://doi.org/10.1023/A:1010091220143" rel="noopener noreferrer"&gt;"The Cross-Entropy Method for Combinatorial and Continuous Optimization"&lt;/a&gt;. The name comes from the original application: minimising the cross-entropy (KL divergence) between a reference distribution and the optimal importance sampling distribution for rare-event simulation.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Core Idea
&lt;/h3&gt;

&lt;p&gt;Rubinstein's insight was that rare-event estimation and optimisation are essentially the same problem. To estimate &lt;code&gt;$P(S(X) \geq \gamma)$&lt;/code&gt; for a rare threshold &lt;code&gt;$\gamma$&lt;/code&gt;, you need to find a sampling distribution that concentrates on high-&lt;code&gt;$S(X)$&lt;/code&gt; regions. The CE method does this by iteratively:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Drawing samples from the current distribution &lt;code&gt;$f(\cdot;\, v_t)$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;Selecting the elite samples (those with &lt;code&gt;$S(X) \geq \gamma_t$&lt;/code&gt;)&lt;/li&gt;
&lt;li&gt;Updating the distribution parameters to minimise the KL divergence to the empirical elite distribution&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;For a Gaussian family, step 3 has a closed-form solution: the mean and variance of the elite samples. This is exactly what our implementation does.&lt;/p&gt;
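
&lt;p&gt;To make the rare-event origin concrete, here is a minimal sketch (illustrative, not taken from the paper): we tilt a Gaussian sampler toward a rare region using exactly the elite-mean/elite-variance update:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

# Rare-event sketch: concentrate a Gaussian sampler on the region x &amp;gt;= 2.5
rng = np.random.default_rng(0)
mu, sigma = 0.0, 1.0
target = 2.5                                   # rare threshold gamma

for t in range(10):
    x = rng.normal(mu, sigma, size=1000)
    gamma_t = min(np.quantile(x, 0.9), target) # adaptive threshold gamma_t
    elite = x[x &amp;gt;= gamma_t]                    # samples that cleared it
    mu, sigma = elite.mean(), elite.std()      # closed-form Gaussian CE update

print(mu, sigma)  # the sampler now concentrates in the region x &amp;gt;= 2.5
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
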

&lt;h3&gt;
  
  
  The Formal Algorithm
&lt;/h3&gt;

&lt;p&gt;From Rubinstein and Kroese (2004), the CEM update for a parametric family &lt;code&gt;$\{f(\cdot;\, v)\}$&lt;/code&gt; is:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dv_%257Bt%252B1%257D%2520%253D%2520%255Carg%255Cmax_v%2520%255Cfrac%257B1%257D%257BN%257D%2520%255Csum_%257Bi%253D1%257D%255E%257BN%257D%2520I%255C%257BS%28X_i%29%2520%255Cgeq%2520%255Cgamma_t%255C%257D%2520%255Cln%2520f%28X_i%253B%255C%252C%2520v%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dv_%257Bt%252B1%257D%2520%253D%2520%255Carg%255Cmax_v%2520%255Cfrac%257B1%257D%257BN%257D%2520%255Csum_%257Bi%253D1%257D%255E%257BN%257D%2520I%255C%257BS%28X_i%29%2520%255Cgeq%2520%255Cgamma_t%255C%257D%2520%255Cln%2520f%28X_i%253B%255C%252C%2520v%29" alt="equation" width="504" height="72"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$I\{\cdot\}$&lt;/code&gt; is the indicator function selecting elite samples. For a multivariate Gaussian with diagonal covariance, this yields:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu_%257Bt%252B1%257D%2520%253D%2520%255Cfrac%257B1%257D%257BN_%257B%255Ctext%257Belite%257D%257D%257D%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Belite%257D%257D%2520X_i%252C%2520%255Cquad%2520%255Csigma_%257Bt%252B1%257D%255E2%2520%253D%2520%255Cfrac%257B1%257D%257BN_%257B%255Ctext%257Belite%257D%257D%257D%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Belite%257D%257D%2520%28X_i%2520-%2520%255Cmu_%257Bt%252B1%257D%29%255E2" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu_%257Bt%252B1%257D%2520%253D%2520%255Cfrac%257B1%257D%257BN_%257B%255Ctext%257Belite%257D%257D%257D%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Belite%257D%257D%2520X_i%252C%2520%255Cquad%2520%255Csigma_%257Bt%252B1%257D%255E2%2520%253D%2520%255Cfrac%257B1%257D%257BN_%257B%255Ctext%257Belite%257D%257D%257D%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Belite%257D%257D%2520%28X_i%2520-%2520%255Cmu_%257Bt%252B1%257D%29%255E2" alt="equation" width="573" height="64"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The sample mean and variance of the elite set. Elegantly simple.&lt;/p&gt;

&lt;h3&gt;
  
  
  From Rare Events to Tetris
&lt;/h3&gt;

&lt;p&gt;The method found its way into reinforcement learning through &lt;strong&gt;Szita and Lörincz (2006)&lt;/strong&gt;, &lt;a href="https://doi.org/10.1162/neco.2006.18.12.2936" rel="noopener noreferrer"&gt;"Learning Tetris Using the Noisy Cross-Entropy Method"&lt;/a&gt;. They made two key modifications for RL:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Noisy updates&lt;/strong&gt;: Adding decaying extra variance to prevent premature convergence (the &lt;code&gt;extra_std&lt;/code&gt; parameter in our code)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Direct policy search&lt;/strong&gt;: Treating the policy's weight vector as the parameter to optimise, with episode return as the objective function&lt;/li&gt;
&lt;/ol&gt;

&lt;blockquote&gt;
&lt;p&gt;"The noisy cross-entropy method adds a time-decreasing noise term to avoid premature convergence of the variance to zero."&lt;br&gt;
&lt;em&gt;Szita and Lörincz (2006)&lt;/em&gt;&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Their noisy CEM achieved record-breaking performance on Tetris at the time, outperforming methods that required orders of magnitude more computation. Our implementation follows their variant faithfully, including the linear noise decay schedule described in Section 3 of their paper.&lt;/p&gt;

&lt;h3&gt;
  
  
  Theoretical Properties
&lt;/h3&gt;

&lt;p&gt;Unlike policy gradient methods, CEM offers no guarantee of convergence to a local optimum. It is a heuristic. However, it has practical advantages:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Embarrassingly parallel&lt;/strong&gt;: All 200 evaluations per iteration are independent&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;No reward shaping needed&lt;/strong&gt;: Works with any scalar objective, even non-differentiable ones&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Robust to noisy evaluations&lt;/strong&gt;: The elite selection acts as a natural filter&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The method's simplicity is also its limitation. As Rubinstein and Kroese note, the Gaussian parametric family assumes the optimal parameter region is unimodal. Multi-modal reward landscapes can trap CEM in a single mode.&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1023/A:1010091220143" rel="noopener noreferrer"&gt;Rubinstein (1999)&lt;/a&gt;, "The Cross-Entropy Method for Combinatorial and Continuous Optimization" - The original CE method paper&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1007/978-1-4757-4321-0" rel="noopener noreferrer"&gt;Rubinstein and Kroese (2004)&lt;/a&gt;, &lt;em&gt;The Cross-Entropy Method: A Unified Approach to Combinatorial Optimization, Monte Carlo Simulation, and Machine Learning&lt;/em&gt; - The comprehensive textbook&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1162/neco.2006.18.12.2936" rel="noopener noreferrer"&gt;Szita and Lörincz (2006)&lt;/a&gt;, "Learning Tetris Using the Noisy Cross-Entropy Method" - The noisy variant for RL&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://arxiv.org/abs/1703.03864" rel="noopener noreferrer"&gt;Salimans et al. (2017)&lt;/a&gt;, "Evolution Strategies as a Scalable Alternative to Reinforcement Learning" - Modern evolution strategies at scale (OpenAI)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Sutton and Barto (2018)&lt;/strong&gt;, Ch. 13 - Policy gradient methods for comparison&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/cross_entropy_method.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Elite fraction sweep&lt;/strong&gt;: Try &lt;code&gt;elite_frac&lt;/code&gt; values of 0.01, 0.1, 0.2, and 0.5. How does selectivity affect convergence speed and stability?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Noisy vs vanilla CEM&lt;/strong&gt;: Set &lt;code&gt;extra_std=0&lt;/code&gt; and compare convergence. Does the noisy variant help on CartPole, or only on harder problems?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Neural network policy&lt;/strong&gt;: Replace the linear policy with a small neural net (8 hidden units). How many CEM iterations does it take to solve CartPole now? At what network size does CEM become impractical?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Different environments&lt;/strong&gt;: Try CEM on &lt;code&gt;Acrobot-v1&lt;/code&gt; or &lt;code&gt;MountainCar-v0&lt;/code&gt;. Which environments does CEM handle well, and which expose its limitations?&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/q-learning-visualizer" rel="noopener noreferrer"&gt;Q-Learning Visualiser&lt;/a&gt; — See value-based RL in action and compare it with the policy search approach of CEM&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/genetic-algorithms-from-scratch" rel="noopener noreferrer"&gt;Genetic Algorithms: From Line Fitting to the Travelling Salesman&lt;/a&gt; - Another population-based, derivative-free optimisation method. CEM replaces crossover and mutation with distribution fitting.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;Policy Gradients: REINFORCE from Scratch with NumPy&lt;/a&gt; - The gradient-based alternative for policy search. Uses backpropagation through the policy, which scales to large networks but requires differentiable objectives.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;Deep Q-Networks: Experience Replay and Target Networks&lt;/a&gt; - Value-based RL with neural networks. A fundamentally different approach that learns what states are valuable rather than directly searching for good policies.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Why is the Cross-Entropy Method called "cross-entropy" if it does not use a loss function?
&lt;/h3&gt;

&lt;p&gt;The name comes from the original application in rare-event simulation, where the algorithm minimises the cross-entropy (KL divergence) between the current sampling distribution and the optimal importance sampling distribution. In the reinforcement learning context, the name persists even though the update reduces to simply computing the mean and variance of the elite samples.&lt;/p&gt;

&lt;h3&gt;
  
  
  How does CEM compare to random search?
&lt;/h3&gt;

&lt;p&gt;Both methods sample candidate policies each iteration, but random search draws from a fixed distribution every time, while CEM updates its distribution based on the best-performing candidates. This directed search means CEM builds on previous successes, converging to good solutions far faster than random search on problems with structure to exploit.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can CEM solve problems with continuous action spaces?
&lt;/h3&gt;

&lt;p&gt;CEM itself optimises continuous policy parameters regardless of the action space; it is the policy architecture that determines how actions are generated. A linear policy with CEM-optimised weights can only produce binary or discrete decisions. For truly continuous action spaces with complex dynamics, you would need a more expressive policy architecture, which increases the parameter count and makes CEM less practical.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the role of the elite fraction hyperparameter?
&lt;/h3&gt;

&lt;p&gt;The elite fraction controls how selective the algorithm is when choosing which candidates inform the next distribution. A smaller fraction (e.g. 0.01) converges faster but risks collapsing onto a local optimum. A larger fraction (e.g. 0.5) explores more broadly but converges more slowly. A value around 0.2 is a common default that balances exploitation and exploration.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why does the noisy CEM variant add extra variance that decays over time?
&lt;/h3&gt;

&lt;p&gt;Without extra variance, the sampling distribution can collapse to near-zero variance too quickly, trapping the search around a potentially suboptimal solution. The decaying noise keeps exploration alive in early iterations when the algorithm is still uncertain about the best region, then gradually disappears to allow precise convergence in later iterations.&lt;/p&gt;

</description>
      <category>reinforcementlearning</category>
      <category>optimisation</category>
    </item>
    <item>
      <title>PCR vs PLS: When Fewer Features Beat More</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Sun, 19 Apr 2026 15:38:56 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/pcr-vs-pls-when-fewer-features-beat-more-2plp</link>
      <guid>https://dev.to/berkan_sesen/pcr-vs-pls-when-fewer-features-beat-more-2plp</guid>
      <description>&lt;p&gt;How much should a baseball team pay its players? The 1986 Major League season gives us 263 hitters with 19 statistics each: at-bats, hits, home runs, years played, and more. Predicting salary from performance sounds like a textbook regression problem, but 19 correlated features make it anything but. Throw them all into a &lt;a href="https://sesen.ai/blog/linear-regression-five-ways" rel="noopener noreferrer"&gt;linear regression&lt;/a&gt; and the model fits the training data beautifully but falls apart on held-out players. The coefficient estimates are wildly unstable, and salary predictions swing by thousands on minor input changes.&lt;/p&gt;

&lt;p&gt;The fix is not a fancier model. It is &lt;em&gt;fewer features&lt;/em&gt;, chosen more carefully. This post covers two classic strategies for doing exactly that: Principal Component Regression (PCR) and Partial Least Squares (PLS).&lt;/p&gt;

&lt;p&gt;By the end, you'll understand how both methods compress correlated features into a handful of components, why PLS typically needs fewer components than PCR, and when each approach is the right tool.&lt;/p&gt;

&lt;h2&gt;
  
  
  Quick Win: Predict Salaries with 6 Features Instead of 19
&lt;/h2&gt;

&lt;p&gt;Click the badge to open the interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/supervised/pcr_vs_pls.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We'll use the classic ISLR Hitters dataset: 263 baseball players with 19 features (at-bats, hits, home runs, years played, etc.) predicting salary in thousands of dollars.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pandas&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.preprocessing&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.decomposition&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;PCA&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.linear_model&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;LinearRegression&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.cross_decomposition&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;PLSRegression&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.model_selection&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;KFold&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cross_val_score&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;train_test_split&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.metrics&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;mean_squared_error&lt;/span&gt;

&lt;span class="c1"&gt;# Load and prepare the Hitters dataset
&lt;/span&gt;&lt;span class="n"&gt;df&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;read_csv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;https://raw.githubusercontent.com/selva86/datasets/master/Hitters.csv&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;dropna&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;dummies&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_dummies&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;League&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Division&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;NewLeague&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]])&lt;/span&gt;
&lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Salary&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;drop&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Salary&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;League&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Division&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;NewLeague&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;astype&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;float64&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;dummies&lt;/span&gt;&lt;span class="p"&gt;[[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;League_N&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Division_W&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;NewLeague_N&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt;
&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Train/test split
&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;train_test_split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;test_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# PCR: PCA on scaled training data, then regression
&lt;/span&gt;&lt;span class="n"&gt;pca&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PCA&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;X_train_pc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pca&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit_transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;X_test_pc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pca&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="c1"&gt;# Use 10-fold CV to find the best number of components
&lt;/span&gt;&lt;span class="n"&gt;kf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;KFold&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_splits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;regr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;LinearRegression&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;mse_by_ncomp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="nf"&gt;cross_val_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;regr&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_train_pc&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to_numpy&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
        &lt;span class="n"&gt;cv&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;kf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scoring&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;neg_mean_squared_error&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
    &lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;mse_by_ncomp&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;score&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;best_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mse_by_ncomp&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Best PCR components: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;CV MSE: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;mse_by_ncomp&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Evaluate on test set
&lt;/span&gt;&lt;span class="n"&gt;regr&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train_pc&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;pcr_test_mse&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;mean_squared_error&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;regr&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test_pc&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;PCR test MSE (&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; components): &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;pcr_test_mse&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Compare to full OLS
&lt;/span&gt;&lt;span class="n"&gt;regr_full&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;LinearRegression&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;regr_full&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train_pc&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ols_test_mse&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;mean_squared_error&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;regr_full&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test_pc&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Full OLS test MSE (19 features): &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;ols_test_mse&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;The result:&lt;/strong&gt; PCR with just 6 components achieves a test MSE of ~112,000, beating full OLS (test MSE ~117,000) using all 19 features. Fewer features, better predictions.&lt;/p&gt;

&lt;h3&gt;
  
  
  PCR vs PLS: The Key Difference
&lt;/h3&gt;

&lt;p&gt;Now let's try PLS, which uses the target variable during dimension reduction:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# PLS: find the best number of components via CV
&lt;/span&gt;&lt;span class="n"&gt;pls_mse&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;pls&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PLSRegression&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_components&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="nf"&gt;cross_val_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;pls&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to_numpy&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
        &lt;span class="n"&gt;cv&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;kf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scoring&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;neg_mean_squared_error&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
    &lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;pls_mse&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;score&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;best_pls_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pls_mse&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;span class="n"&gt;pls_best&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PLSRegression&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_components&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;pls_best&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;pls_test_mse&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;mean_squared_error&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pls_best&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;PLS test MSE (2 components): &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;pls_test_mse&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;PLS with just 2 components&lt;/strong&gt; achieves a test MSE of ~105,000, beating both PCR and OLS. That is the power of supervised dimension reduction: PLS finds the directions that matter for the target, not just the directions of maximum variance.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;Both methods solve the same problem: your 19 features are correlated (career stats like CAtBat, CHits, CRuns all move together), so fitting a separate coefficient for each one leads to noisy, unstable estimates. The solution is to compress correlated features into a smaller set of &lt;strong&gt;components&lt;/strong&gt; before regressing.&lt;/p&gt;
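&lt;p&gt;You can verify that claim directly. A quick sketch, assuming &lt;code&gt;X_train&lt;/code&gt; is a pandas DataFrame with the Hitters feature columns:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Pairwise correlations among the career counting stats
career_cols = ['CAtBat', 'CHits', 'CRuns', 'CRBI']
print(X_train[career_cols].corr().round(2))
# Expect off-diagonal values close to 1.0: these features move together
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;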

&lt;p&gt;The difference is &lt;em&gt;how&lt;/em&gt; they choose those components.&lt;/p&gt;

&lt;h3&gt;
  
  
  PCR: Unsupervised, Then Regress
&lt;/h3&gt;

&lt;p&gt;PCR works in two steps:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;PCA&lt;/strong&gt; finds the directions of maximum variance in &lt;code&gt;$X$&lt;/code&gt;, ignoring &lt;code&gt;$y$&lt;/code&gt; entirely&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Linear regression&lt;/strong&gt; fits &lt;code&gt;$y$&lt;/code&gt; on the top &lt;code&gt;$k$&lt;/code&gt; principal components
&lt;/li&gt;
&lt;/ol&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Step 1: PCA finds directions of maximum variance
&lt;/span&gt;&lt;span class="n"&gt;pca&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PCA&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;X_train_pc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pca&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit_transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;  &lt;span class="c1"&gt;# 19 features → 19 PCs
&lt;/span&gt;
&lt;span class="c1"&gt;# Step 2: Regress salary on just the first k PCs
&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;
&lt;span class="n"&gt;regr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;LinearRegression&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;regr&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train_pc&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The first principal component captures the direction along which the features vary the most. In our Hitters data, PC1 captures 39.9% of the total variance, and by PC7 we're at 93.4%.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F39iav845wbfyqmfdq7yq.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F39iav845wbfyqmfdq7yq.webp" alt="Explained variance per principal component (bars) and cumulative variance (line). The first 7 components capture over 93% of the total variance in the 19 features." width="800" height="417"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;But here's the catch: the directions of maximum variance in &lt;code&gt;$X$&lt;/code&gt; are not necessarily the directions most useful for predicting &lt;code&gt;$y$&lt;/code&gt;. PC1 might capture the spread between high-career and low-career players, but if salary depends more on a subtle interaction between recent performance and league, that signal could be buried in PC8 or PC12.&lt;/p&gt;
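&lt;p&gt;You can check where the signal actually lives by correlating each PC score with salary. A sketch, reusing &lt;code&gt;X_train_pc&lt;/code&gt; and &lt;code&gt;y_train&lt;/code&gt; from the snippet above:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

# Correlation between each principal component score and the target
corrs = [np.corrcoef(X_train_pc[:, i], y_train.to_numpy())[0, 1]
         for i in range(X_train_pc.shape[1])]
for i, c in enumerate(corrs[:8], start=1):
    print(f'PC{i}: correlation with salary = {c:+.2f}')
# If a later, low-variance PC out-correlates an early one,
# variance and predictive relevance have come apart
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;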

&lt;h3&gt;
  
  
  PLS: Supervised from the Start
&lt;/h3&gt;

&lt;p&gt;PLS finds directions that simultaneously:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Explain variance in &lt;code&gt;$X$&lt;/code&gt; (like PCA)&lt;/li&gt;
&lt;li&gt;Correlate with &lt;code&gt;$y$&lt;/code&gt; (unlike PCA)
&lt;/li&gt;
&lt;/ul&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# PLS finds directions that maximise covariance between X and y
&lt;/span&gt;&lt;span class="n"&gt;pls&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PLSRegression&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_components&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;pls&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;predictions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pls&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is why PLS needs only 2 components where PCR needs 6. PLS searches directly for the features that predict salary, while PCR has to hope that the high-variance directions in &lt;code&gt;$X$&lt;/code&gt; also happen to predict &lt;code&gt;$y$&lt;/code&gt;.&lt;/p&gt;

&lt;h3&gt;
  
  
  Choosing the Number of Components
&lt;/h3&gt;

&lt;p&gt;Both methods use 10-fold cross-validation to select the number of components:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;

&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="c1"&gt;# PCR
&lt;/span&gt;&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;mse_by_ncomp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;-o&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;steelblue&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;markersize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;red&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Number of Components&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;10-Fold CV MSE&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;PCR&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# PLS
&lt;/span&gt;&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;pls_mse&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;-s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;darkorange&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;markersize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;green&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Selected: 2 (parsimonious)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;best_pls_k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;red&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;:&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CV min: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_pls_k&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Number of Components&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;10-Fold CV MSE&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;PLS&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;fontsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F2yslidjyargov31umd5p.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F2yslidjyargov31umd5p.webp" alt="Cross-validation MSE curves for PCR (left, minimum at 6 components) and PLS (right, CV minimum at 11 but 2 selected for parsimony). PLS reaches competitive performance with far fewer components." width="800" height="328"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The PCR curve dips at 6 components and rises again: adding noisy components &lt;em&gt;hurts&lt;/em&gt; predictions. The PLS curve is more interesting: the strict CV minimum is at 11 components, but 2 components achieve nearly the same MSE (143,564 vs 142,554). We select 2 because the simpler model generalises better on the test set (MSE 104,839 vs 106,891 with 11). This is a common pattern: when the CV curve is flat near the minimum, prefer the simpler model.&lt;/p&gt;
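&lt;p&gt;One way to encode that preference, in the spirit of the one-standard-error rule, is to accept the smallest component count whose CV error sits within a small tolerance of the minimum. A sketch reusing &lt;code&gt;pls_mse&lt;/code&gt; from the loop above; the 1% threshold is an arbitrary choice:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

pls_mse_arr = np.asarray(pls_mse)
tolerance = 1.01 * pls_mse_arr.min()   # within 1% of the best CV MSE
parsimonious_k = int(np.argmax(pls_mse_arr &lt;= tolerance)) + 1
print(f'CV minimum: {int(np.argmin(pls_mse_arr)) + 1} components; '
      f'within-1% choice: {parsimonious_k}')
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;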

&lt;h3&gt;
  
  
  A Peek Inside the Components
&lt;/h3&gt;

&lt;p&gt;What do these principal components actually capture? The PCA loadings reveal which original features contribute to each component:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;loadings&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;DataFrame&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;pca&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;components_&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;columns&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;PC&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)],&lt;/span&gt;
    &lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;columns&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loadings&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sort_values&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;PC1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ascending&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;round&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fglpi7ibwgf4q7n8f80sh.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fglpi7ibwgf4q7n8f80sh.webp" alt="PCA loadings heatmap showing how each of the 19 features contributes to the first 5 principal components. Career statistics (CRuns, CRBI, CHits) dominate PC1, current-season stats (AtBat, Hits, Runs) dominate PC2, and league indicators dominate PC3." width="800" height="666"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The heatmap reveals the correlation structure clearly. PC1 (39.9% of variance) is dominated by career statistics: CRuns, CRBI, CHits, CAtBat, and CHmRun all have loadings above 0.30. PC2 (21.5%) separates current-season stats (AtBat, Hits, Runs with positive loadings) from career longevity (Years with a negative loading). PC3 picks up the league indicator variables. PCA compresses these correlated groups into single components, which is exactly why dimension reduction works here.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Mathematics
&lt;/h3&gt;

&lt;p&gt;&lt;strong&gt;PCR&lt;/strong&gt; decomposes &lt;code&gt;$X$&lt;/code&gt; using PCA:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DX%2520%253D%2520U%2520%255CSigma%2520V%255ET" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DX%2520%253D%2520U%2520%255CSigma%2520V%255ET" alt="equation" width="124" height="22"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$V$&lt;/code&gt; contains the principal component directions (eigenvectors of &lt;code&gt;$X^TX$&lt;/code&gt;). We keep the first &lt;code&gt;$k$&lt;/code&gt; of those directions, form the score matrix &lt;code&gt;$Z_k = XV_k$&lt;/code&gt;, and regress:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Chat%257By%257D%2520%253D%2520Z_k%2520%255Chat%257B%255Cbeta%257D_k%2520%253D%2520X%2520V_k%2520%28V_k%255ET%2520X%255ET%2520X%2520V_k%29%255E%257B-1%257D%2520V_k%255ET%2520X%255ET%2520y" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Chat%257By%257D%2520%253D%2520Z_k%2520%255Chat%257B%255Cbeta%257D_k%2520%253D%2520X%2520V_k%2520%28V_k%255ET%2520X%255ET%2520X%2520V_k%29%255E%257B-1%257D%2520V_k%255ET%2520X%255ET%2520y" alt="equation" width="418" height="29"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This is equivalent to OLS on the reduced feature set. The key insight: since the PCs are orthogonal, the regression coefficients don't change when you add or remove components. Each component's contribution is independent.&lt;/p&gt;
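&lt;p&gt;You can see this directly: refit with one extra component and compare the shared coefficients. A sketch reusing &lt;code&gt;X_train_pc&lt;/code&gt; and &lt;code&gt;y_train&lt;/code&gt; from above:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from sklearn.linear_model import LinearRegression

# Coefficients on the first 6 PCs, with and without a 7th PC in the model
coef_6 = LinearRegression().fit(X_train_pc[:, :6], y_train).coef_
coef_7 = LinearRegression().fit(X_train_pc[:, :7], y_train).coef_
print(np.allclose(coef_6, coef_7[:6]))  # True: orthogonal regressors decouple
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;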

&lt;p&gt;&lt;strong&gt;PLS&lt;/strong&gt; maximises the covariance between &lt;code&gt;$X$&lt;/code&gt; and &lt;code&gt;$y$&lt;/code&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dw_1%2520%253D%2520%255Carg%255Cmax_%257B%255C%257Cw%255C%257C%253D1%257D%2520%255Ctext%257BCov%257D%28Xw%252C%255C%252C%2520y%29%2520%253D%2520%255Carg%255Cmax_%257B%255C%257Cw%255C%257C%253D1%257D%2520%28Xw%29%255ET%2520y" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dw_1%2520%253D%2520%255Carg%255Cmax_%257B%255C%257Cw%255C%257C%253D1%257D%2520%255Ctext%257BCov%257D%28Xw%252C%255C%252C%2520y%29%2520%253D%2520%255Carg%255Cmax_%257B%255C%257Cw%255C%257C%253D1%257D%2520%28Xw%29%255ET%2520y" alt="equation" width="495" height="43"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The first PLS direction &lt;code&gt;$w_1$&lt;/code&gt; is simply &lt;code&gt;$X^T y$&lt;/code&gt; (normalised): the covariance between each feature and the target. Subsequent directions are found by deflating &lt;code&gt;$X$&lt;/code&gt; and repeating.&lt;/p&gt;
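&lt;p&gt;This is easy to verify against scikit-learn's fitted weights. A sketch: the comparison holds up to sign, and assumes the centred, scaled inputs used in the fits above:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from sklearn.preprocessing import scale

Xs = scale(X_train)
yc = y_train.to_numpy() - y_train.to_numpy().mean()

# First PLS direction by hand: normalised X^T y
w1_manual = Xs.T @ yc
w1_manual /= np.linalg.norm(w1_manual)

w1_fitted = pls_best.x_weights_[:, 0]   # first fitted PLS weight vector
print(np.allclose(np.abs(w1_manual), np.abs(w1_fitted), atol=1e-6))  # True
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;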

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8yg25ykrw159t7kp76hm.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8yg25ykrw159t7kp76hm.webp" alt="PCR finds directions of maximum variance in X (unsupervised, then regresses on y), while PLS finds directions that maximise covariance between X and y (supervised from the start)." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  When PCR Wins, When PLS Wins
&lt;/h3&gt;

&lt;p&gt;&lt;strong&gt;PCR is better when:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;The high-variance directions in &lt;code&gt;$X$&lt;/code&gt; genuinely predict &lt;code&gt;$y$&lt;/code&gt; (common in spectroscopy, genomics)&lt;/li&gt;
&lt;li&gt;You have many features and few observations (PCA provides stable variance estimates)&lt;/li&gt;
&lt;li&gt;You want an unsupervised feature extraction that you can reuse across multiple targets&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;PLS is better when:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;The predictive signal sits in low-variance directions of &lt;code&gt;$X$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;You have a single target and want the most efficient compression&lt;/li&gt;
&lt;li&gt;Your features include many irrelevant high-variance variables&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;In our Hitters example, PLS wins convincingly: 2 components vs 6, and lower test error. The salary signal does not align perfectly with the directions of maximum variance in the batting statistics.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Bias-Variance Tradeoff
&lt;/h3&gt;

&lt;p&gt;Both methods trade bias for lower variance:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Method&lt;/th&gt;
&lt;th&gt;Components&lt;/th&gt;
&lt;th&gt;Test MSE&lt;/th&gt;
&lt;th&gt;RMSE ($ thousands)&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Full OLS&lt;/td&gt;
&lt;td&gt;19&lt;/td&gt;
&lt;td&gt;117,301&lt;/td&gt;
&lt;td&gt;342&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;PCR&lt;/td&gt;
&lt;td&gt;6&lt;/td&gt;
&lt;td&gt;112,167&lt;/td&gt;
&lt;td&gt;335&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;PLS&lt;/td&gt;
&lt;td&gt;2&lt;/td&gt;
&lt;td&gt;104,839&lt;/td&gt;
&lt;td&gt;324&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Ridge&lt;/td&gt;
&lt;td&gt;--&lt;/td&gt;
&lt;td&gt;99,741&lt;/td&gt;
&lt;td&gt;316&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;Full OLS uses all 19 features but has high variance (unstable coefficients). PCR and PLS introduce some bias by discarding information, but the reduction in variance more than compensates. Ridge regression (included for comparison) achieves the lowest error by shrinking coefficients rather than discarding components.&lt;/p&gt;
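&lt;p&gt;For reference, here is roughly how the Ridge row can be reproduced. A sketch; the alpha grid is an assumption, so your exact MSE may differ slightly:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from sklearn.linear_model import RidgeCV
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import scale

# Cross-validated shrinkage strength over a wide log-spaced grid
ridge = RidgeCV(alphas=np.logspace(-3, 3, 100))
ridge.fit(scale(X_train), y_train)
ridge_mse = mean_squared_error(y_test, ridge.predict(scale(X_test)))
print(f'Ridge test MSE: {ridge_mse:,.0f}')
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;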

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzwgoyjjgn3lq6d7iesax.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzwgoyjjgn3lq6d7iesax.webp" alt="Bar chart comparing test MSE across four methods: Full OLS, PCR with 6 components, PLS with 2 components, and Ridge regression." width="800" height="515"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The practical message: when features are correlated, you rarely need all of them. The question is whether to reduce dimensions unsupervised (PCR), supervised (PLS), or regularise without reducing dimensions at all (Ridge).&lt;/p&gt;

&lt;h3&gt;
  
  
  When NOT to Use PCR or PLS
&lt;/h3&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Few features, many observations.&lt;/strong&gt; If &lt;code&gt;$p \ll n$&lt;/code&gt;, multicollinearity is less of a problem and OLS works fine.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Interpretability is critical.&lt;/strong&gt; The principal components are linear combinations of all features, so individual feature effects are obscured. If you need to say "an extra home run is worth $X in salary," use Ridge or Lasso instead.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Non-linear relationships.&lt;/strong&gt; PCR and PLS are linear methods. For non-linear patterns, consider &lt;a href="https://sesen.ai/blog/gaussian-process-regression-from-scratch" rel="noopener noreferrer"&gt;Gaussian process regression&lt;/a&gt; or tree-based models.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Sparse signals.&lt;/strong&gt; If only a few features matter and the rest are noise, Lasso (L1 regularisation) does feature &lt;em&gt;selection&lt;/em&gt; rather than feature &lt;em&gt;combination&lt;/em&gt;, which is usually more effective.&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Deep Dive: The Papers
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Principal Component Regression
&lt;/h3&gt;

&lt;p&gt;The idea of using principal components as regression predictors dates to &lt;strong&gt;Massy (1965)&lt;/strong&gt;, &lt;a href="https://doi.org/10.1080/01621459.1965.10480810" rel="noopener noreferrer"&gt;"Principal Components Regression in Exploratory Statistical Research"&lt;/a&gt;, published in the &lt;em&gt;Journal of the American Statistical Association&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;Massy was working on marketing research problems where survey data had dozens of correlated variables. He proposed a two-step procedure: extract principal components, then regress on the top &lt;code&gt;$k$&lt;/code&gt;. His key insight:&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"By using the principal components as the independent variables in the regression, we avoid the multicollinearity problem since the components are orthogonal."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;The underlying PCA dates back further to &lt;strong&gt;Hotelling (1933)&lt;/strong&gt;, &lt;a href="https://doi.org/10.1037/h0071325" rel="noopener noreferrer"&gt;"Analysis of a complex of statistical variables into principal components"&lt;/a&gt;, &lt;em&gt;Journal of Educational Psychology&lt;/em&gt;. Hotelling formalised the eigenvalue decomposition of the covariance matrix, though the core idea appeared even earlier in Pearson (1901).&lt;/p&gt;

&lt;h3&gt;
  
  
  Partial Least Squares
&lt;/h3&gt;

&lt;p&gt;PLS was developed by &lt;strong&gt;Herman Wold&lt;/strong&gt; in the 1960s and 1970s, originally for econometrics. The foundational paper is &lt;strong&gt;Wold (1975)&lt;/strong&gt;, "Soft modelling by latent variables: the non-linear iterative partial least squares (NIPALS) approach," in &lt;em&gt;Perspectives in Probability and Statistics&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;Herman's son, &lt;strong&gt;Svante Wold&lt;/strong&gt;, later popularised PLS in chemometrics with a landmark review: &lt;strong&gt;Wold, Sjöström &amp;amp; Eriksson (2001)&lt;/strong&gt;, &lt;a href="https://doi.org/10.1016/S0169-7439(01)00155-1" rel="noopener noreferrer"&gt;"PLS-regression: a basic tool of chemometrics"&lt;/a&gt;, &lt;em&gt;Chemometrics and Intelligent Laboratory Systems&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;The modern computational algorithm behind many implementations is &lt;strong&gt;SIMPLS&lt;/strong&gt; by &lt;strong&gt;de Jong (1993)&lt;/strong&gt;, &lt;a href="https://doi.org/10.1016/0169-7439(93)85002-X" rel="noopener noreferrer"&gt;"SIMPLS: An alternative approach to partial least squares regression"&lt;/a&gt;. De Jong's algorithm avoids the iterative deflation of &lt;code&gt;$X$&lt;/code&gt;, making it both faster and numerically more stable. (scikit-learn's &lt;code&gt;PLSRegression&lt;/code&gt;, for the record, uses the classic NIPALS iteration.)&lt;/p&gt;

&lt;h3&gt;
  
  
  The ISLR Connection
&lt;/h3&gt;

&lt;p&gt;This tutorial is based on the lab exercise in &lt;strong&gt;James, Witten, Hastie &amp;amp; Tibshirani (2021)&lt;/strong&gt;, &lt;a href="https://www.statlearning.com/" rel="noopener noreferrer"&gt;&lt;em&gt;An Introduction to Statistical Learning&lt;/em&gt;&lt;/a&gt;, Chapter 6. ISLR provides an excellent treatment of PCR and PLS in the context of the bias-variance tradeoff, alongside Ridge and Lasso regression.&lt;/p&gt;

&lt;p&gt;The Hitters dataset used here has become a standard benchmark for comparing regularisation and dimension reduction methods. With 19 correlated features, it sits in the sweet spot where these methods make a visible difference.&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;ISLR Chapter 6&lt;/strong&gt; (free online) - PCR, PLS, Ridge, and Lasso side by side&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Hastie, Tibshirani &amp;amp; Friedman (2009)&lt;/strong&gt;, &lt;em&gt;The Elements of Statistical Learning&lt;/em&gt;, Chapter 3.5 - Rigorous treatment&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Abdi (2010)&lt;/strong&gt;, &lt;a href="https://doi.org/10.1002/wics.51" rel="noopener noreferrer"&gt;"Partial least squares regression and projection on latent structure regression"&lt;/a&gt; - Excellent modern overview&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/supervised/pcr_vs_pls.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Scree plot.&lt;/strong&gt; Plot the explained variance per component and the cumulative curve. How many components do you need to capture 95% of the variance?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;PLS loadings.&lt;/strong&gt; Compare the PLS weight vectors to the PCA loadings. Which features does PLS prioritise that PCA does not?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Ridge vs PCR.&lt;/strong&gt; Add a Ridge regression (with &lt;code&gt;RidgeCV&lt;/code&gt;) to the comparison. In what sense is Ridge a "soft" version of PCR?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Log-transform the target.&lt;/strong&gt; Salary is right-skewed. Does predicting &lt;code&gt;$\log(\text{Salary})$&lt;/code&gt; change which method wins?&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Understanding PCR and PLS builds directly on &lt;a href="https://sesen.ai/blog/linear-regression-five-ways" rel="noopener noreferrer"&gt;linear regression&lt;/a&gt; and connects to &lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;Bayesian inference&lt;/a&gt; through the regularisation-as-prior interpretation. When the linear assumption breaks down, &lt;a href="https://sesen.ai/blog/gaussian-process-regression-from-scratch" rel="noopener noreferrer"&gt;Gaussian process regression&lt;/a&gt; offers a non-parametric alternative that handles high-dimensional inputs gracefully.&lt;/p&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/regression-playground" rel="noopener noreferrer"&gt;Regression Playground&lt;/a&gt; — Fit and compare regression models interactively in the browser&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/linear-regression-five-ways" rel="noopener noreferrer"&gt;Linear Regression Five Ways&lt;/a&gt; — The foundation both PCR and PLS build on&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/lda-vs-pca-supervised-unsupervised-dimensionality-reduction" rel="noopener noreferrer"&gt;LDA vs PCA: Supervised vs Unsupervised Dimensionality Reduction&lt;/a&gt; — The classification counterpart to PCR vs PLS&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt; — How regularisation connects to Bayesian priors&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/gaussian-process-regression-from-scratch" rel="noopener noreferrer"&gt;Gaussian Process Regression from Scratch&lt;/a&gt; — A non-parametric alternative when linearity breaks down&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is the key difference between PCR and PLS?
&lt;/h3&gt;

&lt;p&gt;PCR finds directions of maximum variance in the features without considering the target variable, then regresses on those directions. PLS finds directions that maximise the covariance between the features and the target simultaneously. Because PLS is supervised from the start, it typically needs fewer components to achieve the same predictive performance.&lt;/p&gt;

&lt;h3&gt;
  
  
  When should I use PCR instead of PLS?
&lt;/h3&gt;

&lt;p&gt;PCR is preferable when the high-variance directions in your features genuinely predict the target, which is common in spectroscopy and genomics. It is also useful when you want an unsupervised feature extraction that can be reused across multiple target variables. PLS is better when the predictive signal sits in low-variance directions or when many high-variance features are irrelevant to the outcome.&lt;/p&gt;

&lt;h3&gt;
  
  
  How do I choose the number of components for PCR or PLS?
&lt;/h3&gt;

&lt;p&gt;Use k-fold cross-validation to evaluate predictive performance at each number of components and select the value that minimises the cross-validation error. When the error curve is flat near the minimum, prefer the simpler model with fewer components, as it tends to generalise better on unseen data.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why did PLS with 2 components beat PCR with 6 components on the Hitters dataset?
&lt;/h3&gt;

&lt;p&gt;The salary signal in the baseball data does not align well with the directions of maximum variance in the batting statistics. Career statistics dominate the first few principal components, but salary depends on a subtler combination of recent performance and league factors. PLS finds these salary-relevant directions directly, so it needs far fewer components.&lt;/p&gt;

&lt;h3&gt;
  
  
  How does PCR compare to Ridge regression?
&lt;/h3&gt;

&lt;p&gt;Both methods address multicollinearity, but in different ways. PCR discards the least important principal components entirely, introducing a hard cutoff. Ridge regression shrinks all coefficients towards zero without discarding any, acting as a soft version of dimension reduction. Ridge often achieves lower test error because it retains some information from every direction.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can I interpret individual feature effects with PCR or PLS?
&lt;/h3&gt;

&lt;p&gt;Not directly. The components are linear combinations of all original features, so individual feature effects are obscured. If you need to say that a specific feature is worth a certain amount, use Ridge or Lasso regression instead, which produce interpretable coefficients for each original variable.&lt;/p&gt;

</description>
      <category>supervisedlearning</category>
      <category>statistics</category>
    </item>
    <item>
      <title>Text Classification from Scratch: TF-IDF and Naive Bayes</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Fri, 17 Apr 2026 12:46:47 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/text-classification-from-scratch-tf-idf-and-naive-bayes-3pff</link>
      <guid>https://dev.to/berkan_sesen/text-classification-from-scratch-tf-idf-and-naive-bayes-3pff</guid>
      <description>&lt;p&gt;Every morning, your inbox separates spam from real email. News apps sort articles into sports, tech, and politics. Customer support systems route tickets to the right team. Behind all of these is text classification: teaching a machine to read a document and assign it a category.&lt;/p&gt;

&lt;p&gt;The building blocks are simpler than you might expect. You need a way to convert text into numbers (TF-IDF), a classifier that works well with sparse, high-dimensional data (Naive Bayes), and a few lines of code to tie them together. No deep learning, no GPUs, no embeddings.&lt;/p&gt;

&lt;p&gt;By the end of this post, you'll classify news articles into 20 categories with 77% accuracy using just 10 lines of Python, then push that to 84% with hyperparameter tuning. You'll understand exactly how TF-IDF works and why the "naive" independence assumption in Naive Bayes is a feature, not a bug.&lt;/p&gt;

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;Click the badge to open the interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/nlp/tfidf_naive_bayes.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Here's the complete classifier. We use scikit-learn's 20 Newsgroups dataset, which contains around 18,000 posts across 20 topics, from computer graphics to religion to space exploration:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.datasets&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;fetch_20newsgroups&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.feature_extraction.text&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;TfidfTransformer&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.naive_bayes&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;MultinomialNB&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.pipeline&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Pipeline&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.metrics&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;accuracy_score&lt;/span&gt;

&lt;span class="c1"&gt;# Load training and test data
&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;fetch_20newsgroups&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;subset&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;train&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;twenty_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;fetch_20newsgroups&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;subset&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;test&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Build the pipeline: raw text → word counts → TF-IDF → Naive Bayes
&lt;/span&gt;&lt;span class="n"&gt;text_clf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Pipeline&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;vect&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tfidf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;TfidfTransformer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;clf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;MultinomialNB&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
&lt;span class="p"&gt;])&lt;/span&gt;

&lt;span class="c1"&gt;# Train and evaluate
&lt;/span&gt;&lt;span class="n"&gt;text_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;predicted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;text_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Accuracy: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;accuracy_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;predicted&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Accuracy: 77.4%
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;With 10 lines of modelling code, we classify documents into one of 20 categories at 77.4% accuracy on unseen data. Random guessing would give 5%.&lt;/p&gt;

&lt;p&gt;Let's test it on fresh sentences the model has never seen:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;docs_new&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;OpenGL shading techniques for real-time rendering&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;The Detroit Tigers signed a new pitcher today&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;NASA launched the James Webb telescope last year&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Is there evidence for the existence of God?&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="n"&gt;predicted_new&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;text_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;docs_new&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;doc&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;category&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;docs_new&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;predicted_new&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target_names&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;category&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt;&lt;span class="mi"&gt;28&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;  ←  &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;doc&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;            comp.graphics  ←  OpenGL shading techniques for real-time rendering
        rec.sport.baseball  ←  The Detroit Tigers signed a new pitcher today
                 sci.space  ←  NASA launched the James Webb telescope last year
    soc.religion.christian  ←  Is there evidence for the existence of God?
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The model correctly identifies the topic of each sentence. It works by finding which words are most characteristic of each category.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F6d7j0fw1g64gfr0odiqn.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F6d7j0fw1g64gfr0odiqn.webp" alt="Confusion matrix heatmap for the Naive Bayes classifier on 20 Newsgroups. The diagonal shows correct predictions; off-diagonal cells reveal common confusions between related topics." width="800" height="707"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The confusion matrix reveals where the classifier struggles. Related categories like &lt;code&gt;comp.sys.ibm.pc.hardware&lt;/code&gt; and &lt;code&gt;comp.sys.mac.hardware&lt;/code&gt; (both about computer hardware) are frequently confused, as are &lt;code&gt;talk.religion.misc&lt;/code&gt; and &lt;code&gt;soc.religion.christian&lt;/code&gt;. These make intuitive sense: documents about Mac hardware and PC hardware use very similar vocabulary.&lt;/p&gt;
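&lt;p&gt;A plot like the one above can be produced in a few lines. A sketch, reusing &lt;code&gt;predicted&lt;/code&gt; from the pipeline above:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

# Rows are true categories, columns are predicted categories
fig, ax = plt.subplots(figsize=(10, 10))
ConfusionMatrixDisplay.from_predictions(
    twenty_test.target, predicted,
    display_labels=twenty_train.target_names,
    xticks_rotation='vertical', ax=ax, colorbar=False,
)
plt.tight_layout()
plt.show()
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;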

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;Three components work in sequence: CountVectorizer turns text into word counts, TfidfTransformer re-weights those counts to highlight distinctive words, and MultinomialNB learns which words signal which categories.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzh05hb9r806g3498dsss.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzh05hb9r806g3498dsss.webp" alt="The text classification pipeline: raw text flows through tokenisation, word counting, TF-IDF weighting, and finally the Naive Bayes classifier to produce a category prediction." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  Step 1: Turning Text into Numbers
&lt;/h3&gt;

&lt;p&gt;A machine learning model can't read English. It needs numbers. The simplest conversion is the &lt;strong&gt;bag of words&lt;/strong&gt;: count how many times each word appears in a document, ignoring order entirely.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.feature_extraction.text&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;CountVectorizer&lt;/span&gt;

&lt;span class="n"&gt;corpus&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;The cat sat on the mat&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;The dog sat on the log&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;The cat chased the dog&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;vectorizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;vectorizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit_transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;corpus&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vectorizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_feature_names_out&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="c1"&gt;# ['cat', 'chased', 'dog', 'log', 'mat', 'on', 'sat', 'the']
&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;toarray&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="c1"&gt;# [[1, 0, 0, 0, 1, 1, 1, 2],
#  [0, 0, 1, 1, 0, 1, 1, 2],
#  [1, 1, 1, 0, 0, 0, 0, 2]]
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Each row is a document. Each column is a word from the vocabulary. The value is the word count. Notice that "the" always gets a count of 2, regardless of the document. It's everywhere, so it carries no information about which document you're looking at.&lt;/p&gt;

&lt;p&gt;On the 20 Newsgroups training set, CountVectorizer discovers around 130,000 unique tokens. Each document becomes a vector of 130,000 dimensions, mostly zeros (since any single post uses only a tiny fraction of the full vocabulary).&lt;/p&gt;
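&lt;p&gt;You can confirm both numbers from the fitted pipeline. A sketch:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Vocabulary size and sparsity of the count matrix
count_vect = text_clf.named_steps['vect']
X_counts = count_vect.transform(twenty_train.data)
n_docs, n_vocab = X_counts.shape
density = X_counts.nnz / (n_docs * n_vocab)
print(f'{n_vocab:,} tokens; {density:.3%} of entries are non-zero')
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;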

&lt;h3&gt;
  
  
  Step 2: Weighting Words That Matter
&lt;/h3&gt;

&lt;p&gt;Not all words are equally informative. Words like "the", "is", and "a" appear in every document. What we want are words that are common within a specific category but rare overall. This is exactly what &lt;strong&gt;TF-IDF&lt;/strong&gt; (Term Frequency, Inverse Document Frequency) captures.&lt;/p&gt;

&lt;p&gt;The weight for word &lt;code&gt;$t$&lt;/code&gt; in document &lt;code&gt;$d$&lt;/code&gt; is:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BTF-IDF%257D%28t%252C%2520d%29%2520%253D%2520%255Ctext%257BTF%257D%28t%252C%2520d%29%2520%255Ctimes%2520%255Ctext%257BIDF%257D%28t%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BTF-IDF%257D%28t%252C%2520d%29%2520%253D%2520%255Ctext%257BTF%257D%28t%252C%2520d%29%2520%255Ctimes%2520%255Ctext%257BIDF%257D%28t%29" alt="equation" width="354" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Where:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;TF&lt;/strong&gt; (term frequency) = how often the word appears in this document&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;IDF&lt;/strong&gt; (inverse document frequency) = &lt;code&gt;$\log\!\frac{1+N}{1+n_t}+1$&lt;/code&gt;, where &lt;code&gt;$N$&lt;/code&gt; is the total number of documents and &lt;code&gt;$n_t$&lt;/code&gt; is the number of documents containing word &lt;code&gt;$t$&lt;/code&gt;
&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;A word that appears in every document gets a low IDF, shrinking its weight. A word that appears in only a few documents gets a high IDF, amplifying its signal.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.feature_extraction.text&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;TfidfTransformer&lt;/span&gt;

&lt;span class="n"&gt;tfidf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;TfidfTransformer&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;X_tfidf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfidf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit_transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;round&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_tfidf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;toarray&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fb4j1bjzzhbbztenq0gn1.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fb4j1bjzzhbbztenq0gn1.webp" alt="TF-IDF heatmap for the toy corpus. Common words like " width="800" height="298"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;After TF-IDF weighting, the document vectors highlight what's distinctive about each text rather than what's common across all of them.&lt;/p&gt;
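
&lt;p&gt;To confirm the intuition, inspect the learned IDF values from the &lt;code&gt;vectorizer&lt;/code&gt; and &lt;code&gt;tfidf&lt;/code&gt; objects fitted above; the ubiquitous "the" should get the smallest weight:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;for word, idf in zip(vectorizer.get_feature_names_out(), tfidf.idf_):
    print(f'{word}: {idf:.2f}')
# "the" gets the minimum IDF (it appears in all three documents);
# words unique to one document, like "chased" and "mat", get the maximum.
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;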

&lt;h3&gt;
  
  
  Step 3: Naive Bayes Classification
&lt;/h3&gt;

&lt;p&gt;Naive Bayes applies &lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;Bayes' theorem&lt;/a&gt; to classify documents. Given a document with words &lt;code&gt;$w_1, w_2, \ldots, w_n$&lt;/code&gt;, it computes:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28%255Ctext%257Bclass%257D%2520%255Cmid%2520w_1%252C%2520w_2%252C%2520%255Cldots%252C%2520w_n%29%2520%255Cpropto%2520P%28%255Ctext%257Bclass%257D%29%2520%255Cprod_%257Bi%253D1%257D%255E%257Bn%257D%2520P%28w_i%2520%255Cmid%2520%255Ctext%257Bclass%257D%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28%255Ctext%257Bclass%257D%2520%255Cmid%2520w_1%252C%2520w_2%252C%2520%255Cldots%252C%2520w_n%29%2520%255Cpropto%2520P%28%255Ctext%257Bclass%257D%29%2520%255Cprod_%257Bi%253D1%257D%255E%257Bn%257D%2520P%28w_i%2520%255Cmid%2520%255Ctext%257Bclass%257D%29" alt="equation" width="547" height="68"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The "naive" part is the assumption that words are &lt;strong&gt;conditionally independent&lt;/strong&gt; given the class. This is obviously wrong: the word "neural" is far more likely to appear near "network" than near "baseball". But the simplification works remarkably well in practice because:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;We only need the ranking right&lt;/strong&gt;, not the exact probabilities. If &lt;code&gt;$P(\text{sci.space} \mid \text{doc})$&lt;/code&gt; is the highest, the prediction is correct even if the probability value itself is off.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Independence errors tend to cancel out&lt;/strong&gt; across thousands of features.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The alternative (modelling all word dependencies) is intractable&lt;/strong&gt; for vocabularies of 130,000 words.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The &lt;code&gt;MultinomialNB&lt;/code&gt; variant uses word counts (or TF-IDF weights) as features and models &lt;code&gt;$P(w_i \mid \text{class})$&lt;/code&gt; as a multinomial distribution. The parameters are estimated via &lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;maximum likelihood&lt;/a&gt;: the probability of word &lt;code&gt;$w_i$&lt;/code&gt; in class &lt;code&gt;$c$&lt;/code&gt; is simply the fraction of times &lt;code&gt;$w_i$&lt;/code&gt; appears in training documents of class &lt;code&gt;$c$&lt;/code&gt;, with Laplace smoothing to handle words never seen in training.&lt;/p&gt;
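
&lt;p&gt;To make that estimate concrete, here is a minimal sketch with invented counts for a two-class, three-word problem; it applies the smoothed-fraction rule directly rather than calling scikit-learn:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

# Hypothetical aggregate word counts (rows: classes, columns: vocabulary words)
counts = np.array([[3, 0, 1],
                   [0, 4, 2]])
alpha = 1.0  # Laplace smoothing: a pseudocount for every word-class pair

totals = counts.sum(axis=1, keepdims=True)  # total word occurrences per class
p_w_given_c = (counts + alpha) / (totals + alpha * counts.shape[1])

print(np.round(p_w_given_c, 3))
# [[0.571 0.143 0.286]
#  [0.111 0.556 0.333]]  each row sums to 1; no zero probabilities survive
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;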

&lt;h3&gt;
  
  
  The Pipeline: Composing the Steps
&lt;/h3&gt;

&lt;p&gt;Scikit-learn's &lt;code&gt;Pipeline&lt;/code&gt; chains these three transformations so you can treat the entire workflow as a single estimator:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;text_clf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Pipeline&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;vect&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;     &lt;span class="c1"&gt;# raw text → word counts
&lt;/span&gt;    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tfidf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;TfidfTransformer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;   &lt;span class="c1"&gt;# word counts → TF-IDF weights
&lt;/span&gt;    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;clf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;MultinomialNB&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;        &lt;span class="c1"&gt;# TF-IDF vectors → class predictions
&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;When you call &lt;code&gt;text_clf.fit(X, y)&lt;/code&gt;, it runs &lt;code&gt;CountVectorizer.fit_transform()&lt;/code&gt;, feeds the output to &lt;code&gt;TfidfTransformer.fit_transform()&lt;/code&gt;, then passes the result to &lt;code&gt;MultinomialNB.fit()&lt;/code&gt;. At prediction time, the same chain runs in sequence. This also means you can do grid search over any parameter in the pipeline using the double-underscore naming convention (&lt;code&gt;vect__ngram_range&lt;/code&gt;, &lt;code&gt;clf__alpha&lt;/code&gt;).&lt;/p&gt;
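
&lt;p&gt;For completeness, here is the whole chain fitted and evaluated end to end; a sketch assuming the standard &lt;code&gt;fetch_20newsgroups&lt;/code&gt; train/test splits used throughout this post:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

twenty_train = fetch_20newsgroups(subset='train')
twenty_test = fetch_20newsgroups(subset='test')

text_clf = Pipeline([
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('clf', MultinomialNB()),
])
text_clf.fit(twenty_train.data, twenty_train.target)
acc = accuracy_score(twenty_test.target, text_clf.predict(twenty_test.data))
print(f'Baseline accuracy: {acc:.1%}')  # 77.4%, the figure quoted below
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;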

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Beating the Baseline
&lt;/h3&gt;

&lt;p&gt;Naive Bayes at 77.4% is a strong starting point, but we can improve it in three ways: removing noise (stop words), capturing phrases (bigrams), and tuning the smoothing parameter.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Stop words&lt;/strong&gt; are common words ("the", "is", "at") that carry little discriminative value. Removing them reduces noise and bumps accuracy from 77.4% to 81.7%:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;text_clf_stop&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Pipeline&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;vect&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stop_words&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;english&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tfidf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;TfidfTransformer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;clf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;MultinomialNB&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;text_clf_stop&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;NB + stop words: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;accuracy_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text_clf_stop&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# NB + stop words: 81.7%
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;A 4-point gain for one parameter change.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Grid search&lt;/strong&gt; systematically explores combinations of pipeline parameters. The naming convention (&lt;code&gt;vect__&lt;/code&gt;, &lt;code&gt;tfidf__&lt;/code&gt;, &lt;code&gt;clf__&lt;/code&gt;) lets you reach into any pipeline step:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.model_selection&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;GridSearchCV&lt;/span&gt;

&lt;span class="n"&gt;parameters&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;vect__ngram_range&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;[(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)],&lt;/span&gt;  &lt;span class="c1"&gt;# unigrams vs unigrams+bigrams
&lt;/span&gt;    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tfidf__use_idf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;         &lt;span class="c1"&gt;# use IDF weighting or not
&lt;/span&gt;    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;clf__alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1e-2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;              &lt;span class="c1"&gt;# smoothing strength
&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;

&lt;span class="n"&gt;gs_clf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;GridSearchCV&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text_clf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cv&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_jobs&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;gs_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best CV score: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;gs_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_score_&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best params: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;gs_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_params_&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Test accuracy: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;accuracy_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gs_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Best CV score: 91.6%
# Best params: {'clf__alpha': 0.001, 'tfidf__use_idf': True, 'vect__ngram_range': (1, 2)}
# Test accuracy: 83.6%
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The best configuration uses bigrams (&lt;code&gt;ngram_range=(1,2)&lt;/code&gt;), IDF weighting, and weak smoothing (&lt;code&gt;alpha=0.001&lt;/code&gt;). Bigrams capture phrases like "White House" or "hard drive" that individual words miss. The 5-fold CV score (91.6%) is higher than the test accuracy (83.6%) because cross-validation evaluates on data drawn from the same distribution as the training set, whereas the 20 Newsgroups test split is chronological: it contains posts written after the training posts, with authors, topics, and writing styles the model never saw.&lt;/p&gt;

&lt;p&gt;If you've read our &lt;a href="https://sesen.ai/blog/hyperparameter-optimization-grid-random-bayesian" rel="noopener noreferrer"&gt;hyperparameter optimisation post&lt;/a&gt;, you'll recognise grid search as the brute-force baseline. With only 8 combinations to evaluate here, it's fast enough.&lt;/p&gt;

&lt;h3&gt;
  
  
  SVM: A Stronger Classifier
&lt;/h3&gt;

&lt;p&gt;Swapping Naive Bayes for a linear SVM (support vector machine) lifts the untuned baseline by five points, a bigger jump than stop-word removal alone:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.linear_model&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;SGDClassifier&lt;/span&gt;

&lt;span class="n"&gt;text_clf_svm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Pipeline&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;vect&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tfidf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;TfidfTransformer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;clf-svm&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;SGDClassifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;hinge&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;penalty&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;l2&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                               &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                               &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;text_clf_svm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;SVM accuracy: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;accuracy_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text_clf_svm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# SVM accuracy: 82.4%
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That's 82.4% out of the box, without any tuning. Grid search for SVM yields 83.5%, virtually identical to the tuned Naive Bayes.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fopvvu7yy9dotcwi6cpj5.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fopvvu7yy9dotcwi6cpj5.webp" alt="Accuracy comparison: Naive Bayes baseline (77.4%), NB with stop words (81.7%), SVM baseline (82.4%), NB tuned (83.6%), SVM tuned (83.5%)." width="800" height="434"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The story is clear: the biggest gains come from better feature representation (bigrams, stop word removal, IDF weighting) rather than the choice of classifier. With good features, even the "naive" model performs competitively.&lt;/p&gt;

&lt;h3&gt;
  
  
  What the Model Actually Learns
&lt;/h3&gt;

&lt;p&gt;What words does the classifier rely on? Raw class-conditional probabilities are dominated by common words like "the" and "of". To find truly discriminative features, we compare each word's log-probability within a class against its average across all classes:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.feature_extraction.text&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;TfidfVectorizer&lt;/span&gt;

&lt;span class="n"&gt;tfidf_vect&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;TfidfVectorizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stop_words&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;english&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_df&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;min_df&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;X_tfidf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfidf_vect&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit_transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;clf_disc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;MultinomialNB&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_tfidf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;feature_names&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tfidf_vect&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_feature_names_out&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="n"&gt;log_probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;clf_disc&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feature_log_prob_&lt;/span&gt;
&lt;span class="n"&gt;mean_log_prob&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_probs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;discriminativeness&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_probs&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mean_log_prob&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;category&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target_names&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;top_indices&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;discriminativeness&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;argsort&lt;/span&gt;&lt;span class="p"&gt;()[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;:][::&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;category&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;, &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;feature_names&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;top_indices&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frw1ez38576psnc0rd6np.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frw1ez38576psnc0rd6np.webp" alt="Most discriminative words for four categories: comp.graphics, rec.sport.baseball, sci.space, and talk.politics.mideast." width="800" height="599"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The model learns sensible patterns. &lt;code&gt;sci.space&lt;/code&gt; relies on words like "space", "orbit", and "nasa". &lt;code&gt;rec.sport.baseball&lt;/code&gt; relies on "baseball", "team", and "pitching". &lt;code&gt;talk.politics.mideast&lt;/code&gt; picks up "israel", "armenian", and "turkish". These are the words that carry the strongest evidence for each category, well beyond their background frequency.&lt;/p&gt;

&lt;h3&gt;
  
  
  Stemming: Reducing Words to Roots
&lt;/h3&gt;

&lt;p&gt;Stemming maps words to their root form ("running" to "run", "computers" to "comput"). This merges related word forms into a single feature, reducing vocabulary size:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;nltk&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;nltk.stem.snowball&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;SnowballStemmer&lt;/span&gt;

&lt;span class="n"&gt;nltk&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;download&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;punkt&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;quiet&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;stemmer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;SnowballStemmer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;english&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ignore_stopwords&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;StemmedCountVectorizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;build_analyzer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;analyzer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;build_analyzer&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;doc&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;stemmer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stem&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;analyzer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;doc&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;

&lt;span class="n"&gt;text_clf_stemmed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Pipeline&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;vect&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;StemmedCountVectorizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stop_words&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;english&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tfidf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;TfidfTransformer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;clf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;MultinomialNB&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;fit_prior&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;text_clf_stemmed&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;NB + stemming + stop words: &lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
      &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;accuracy_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text_clf_stemmed&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Stemming often gives a small additional boost. The code above uses the Snowball stemmer, a refinement of Porter's classic 1980 algorithm that handles irregular forms more gracefully.&lt;/p&gt;
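
&lt;p&gt;A quick sanity check of what the &lt;code&gt;stemmer&lt;/code&gt; defined above actually does:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;for w in ('running', 'computers', 'graphics', 'orbiting'):
    print(w, '-&gt;', stemmer.stem(w))
# running -&gt; run, computers -&gt; comput, graphics -&gt; graphic, orbiting -&gt; orbit
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;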

&lt;h3&gt;
  
  
  When NOT to Use Bag-of-Words
&lt;/h3&gt;

&lt;p&gt;This approach has clear limitations:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Word order is lost.&lt;/strong&gt; "Dog bites man" and "man bites dog" produce the same vector (demonstrated in the sketch after this list). For tasks where order matters (sentiment analysis, textual entailment), you need sequence models or contextual embeddings.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Synonyms are invisible.&lt;/strong&gt; If test documents use different words for the same concepts, they won't match. Pre-trained embeddings (Word2Vec, BERT) capture semantic similarity.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Short documents suffer.&lt;/strong&gt; With only a few words, the sparse vector is too noisy for reliable classification. Transformer models handle short texts much better.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Scalability ceiling.&lt;/strong&gt; As the number of closely related categories grows, their vocabularies overlap more and the independence assumption costs more accuracy, since fewer distinctive words remain to separate the classes.&lt;/li&gt;
&lt;/ol&gt;
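
&lt;p&gt;The word-order blind spot in point 1 is easy to demonstrate: the two sentences below vectorise to identical rows.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;from sklearn.feature_extraction.text import CountVectorizer

pair = CountVectorizer().fit_transform(['Dog bites man', 'Man bites dog']).toarray()
print(pair)                        # both rows are [1, 1, 1]
print((pair[0] == pair[1]).all())  # True: the vectors are indistinguishable
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;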

&lt;p&gt;For many practical applications, TF-IDF with Naive Bayes remains hard to beat when you factor in the ratio of performance to complexity. It trains in seconds, requires no GPU, and produces interpretable results.&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;h3&gt;
  
  
  McCallum &amp;amp; Nigam (1998)
&lt;/h3&gt;

&lt;p&gt;The foundational paper for Naive Bayes text classification is &lt;strong&gt;McCallum, A. &amp;amp; Nigam, K. (1998)&lt;/strong&gt; &lt;a href="https://www.cs.cmu.edu/~knigam/papers/multinomial-aaaiws98.pdf" rel="noopener noreferrer"&gt;"A Comparison of Event Models for Naive Bayes Text Classification"&lt;/a&gt;, presented at the AAAI Workshop on Learning for Text Categorization.&lt;/p&gt;

&lt;p&gt;They compared two Naive Bayes variants for text:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Multi-variate Bernoulli&lt;/strong&gt;: each word is a binary feature (present or absent). This is &lt;code&gt;BernoulliNB&lt;/code&gt; in scikit-learn.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Multinomial&lt;/strong&gt;: each word is a count feature. This is the &lt;code&gt;MultinomialNB&lt;/code&gt; our pipeline uses.&lt;/li&gt;
&lt;/ul&gt;

&lt;blockquote&gt;
&lt;p&gt;"We find that the multinomial model is almost uniformly superior, especially for large vocabulary sizes."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;The multinomial model works better because it uses word frequency information. A document mentioning "baseball" 15 times is stronger evidence for &lt;code&gt;rec.sport.baseball&lt;/code&gt; than one mentioning it once. The Bernoulli model discards this frequency signal entirely.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F26alprhgzok3pekyk8by.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F26alprhgzok3pekyk8by.webp" alt="Comparison of the two Naive Bayes event models for text: the multivariate Bernoulli model uses binary word presence, while the multinomial model uses word counts, capturing frequency information." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  The Multinomial Model
&lt;/h3&gt;

&lt;p&gt;Formally, the predicted class for a document &lt;code&gt;$d$&lt;/code&gt; is:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Chat%257Bc%257D%2520%253D%2520%255Carg%255Cmax_c%2520%255Cleft%255B%255Clog%2520P%28c%29%2520%252B%2520%255Csum_%257Bi%253D1%257D%255E%257B%257CV%257C%257D%2520n_i%28d%29%2520%255Clog%2520P%28w_i%2520%255Cmid%2520c%29%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Chat%257Bc%257D%2520%253D%2520%255Carg%255Cmax_c%2520%255Cleft%255B%255Clog%2520P%28c%29%2520%252B%2520%255Csum_%257Bi%253D1%257D%255E%257B%257CV%257C%257D%2520n_i%28d%29%2520%255Clog%2520P%28w_i%2520%255Cmid%2520c%29%255Cright%255D" alt="equation" width="497" height="90"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Where:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;$P(c)$&lt;/code&gt; is the class prior (fraction of training documents in class &lt;code&gt;$c$&lt;/code&gt;)&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$n_i(d)$&lt;/code&gt; is the count of word &lt;code&gt;$w_i$&lt;/code&gt; in document &lt;code&gt;$d$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$P(w_i \mid c)$&lt;/code&gt; is estimated with Laplace smoothing:&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28w_i%2520%255Cmid%2520c%29%2520%253D%2520%255Cfrac%257Bn_%257Bic%257D%2520%252B%2520%255Calpha%257D%257B%255Csum_%257Bj%253D1%257D%255E%257B%257CV%257C%257D%2520n_%257Bjc%257D%2520%252B%2520%255Calpha%2520%257CV%257C%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28w_i%2520%255Cmid%2520c%29%2520%253D%2520%255Cfrac%257Bn_%257Bic%257D%2520%252B%2520%255Calpha%257D%257B%255Csum_%257Bj%253D1%257D%255E%257B%257CV%257C%257D%2520n_%257Bjc%257D%2520%252B%2520%255Calpha%2520%257CV%257C%257D" alt="equation" width="301" height="65"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The smoothing parameter &lt;code&gt;$\alpha$&lt;/code&gt; prevents zero probabilities for words that never appeared in a particular class during training. Our grid search found &lt;code&gt;$\alpha = 0.001$&lt;/code&gt; optimal, meaning the model trusts the training data more and smooths less aggressively than the default &lt;code&gt;$\alpha = 1.0$&lt;/code&gt;.&lt;/p&gt;
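
&lt;p&gt;To see what &lt;code&gt;$\alpha = 0.001$&lt;/code&gt; means in practice, here is a toy calculation (the counts are invented) of the smoothed probability for a word never observed in a class:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;n_ic = 0        # occurrences of the word in class c: never seen in training
total = 10_000  # total word occurrences in class c (hypothetical)
V = 130_000     # vocabulary size, roughly as on 20 Newsgroups

for alpha in (1.0, 0.001):
    p = (n_ic + alpha) / (total + alpha * V)
    print(f'alpha={alpha}: P(w|c) = {p:.2e}')
# alpha=1.0:   P(w|c) = 7.14e-06  (heavy smoothing keeps unseen words plausible)
# alpha=0.001: P(w|c) = 9.87e-08  (weak smoothing penalises them far harder)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;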

&lt;h3&gt;
  
  
  TF-IDF: Salton &amp;amp; Buckley (1988)
&lt;/h3&gt;

&lt;p&gt;TF-IDF was formalised by &lt;strong&gt;Salton, G. &amp;amp; Buckley, C. (1988)&lt;/strong&gt; &lt;a href="https://doi.org/10.1016/0306-4573(88)90021-0" rel="noopener noreferrer"&gt;"Term-weighting approaches in automatic text retrieval"&lt;/a&gt;, &lt;em&gt;Information Processing &amp;amp; Management&lt;/em&gt;. The core idea predates this work: Sparck Jones proposed inverse document frequency in 1972.&lt;/p&gt;

&lt;p&gt;Scikit-learn's variant uses:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BIDF%257D%28t%29%2520%253D%2520%255Clog%255C%21%255Cfrac%257B1%2520%252B%2520N%257D%257B1%2520%252B%2520n_t%257D%2520%252B%25201" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BIDF%257D%28t%29%2520%253D%2520%255Clog%255C%21%255Cfrac%257B1%2520%252B%2520N%257D%257B1%2520%252B%2520n_t%257D%2520%252B%25201" alt="equation" width="246" height="55"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The "+1" terms prevent division by zero and ensure no word gets zero weight. After computing TF-IDF, each document vector is L2-normalised to unit length.&lt;/p&gt;

&lt;h3&gt;
  
  
  Historical Context
&lt;/h3&gt;

&lt;p&gt;Text classification has a long lineage:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Maron (1961)&lt;/strong&gt; — First automatic text classification using probabilistic indexing&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Salton (1971)&lt;/strong&gt; — The SMART retrieval system, introducing many weighting schemes&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Sparck Jones (1972)&lt;/strong&gt; — Inverse document frequency&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Lewis (1998)&lt;/strong&gt; — The Reuters benchmark that standardised evaluation&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Joachims (1998)&lt;/strong&gt; — Showed SVMs outperform NB on text (our results confirm this: 82.4% vs 77.4%)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;McCallum &amp;amp; Nigam (1998)&lt;/strong&gt; — Systematic comparison of NB event models&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Today, transformer-based models (BERT, GPT) dominate text classification benchmarks. But TF-IDF with Naive Bayes remains the standard baseline for its speed, interpretability, and surprising competitiveness.&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://www.cs.cmu.edu/~knigam/papers/multinomial-aaaiws98.pdf" rel="noopener noreferrer"&gt;McCallum &amp;amp; Nigam (1998)&lt;/a&gt; — Multinomial vs Bernoulli NB for text&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1016/0306-4573(88)90021-0" rel="noopener noreferrer"&gt;Salton &amp;amp; Buckley (1988)&lt;/a&gt; — Systematic evaluation of TF-IDF variants&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://link.springer.com/chapter/10.1007/BFb0026683" rel="noopener noreferrer"&gt;Joachims (1998)&lt;/a&gt; — Text categorisation with SVMs&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Manning, Raghavan &amp;amp; Schütze (2008)&lt;/strong&gt; &lt;a href="https://nlp.stanford.edu/IR-book/" rel="noopener noreferrer"&gt;&lt;em&gt;Introduction to Information Retrieval&lt;/em&gt;&lt;/a&gt; — Free textbook covering TF-IDF, NB, and SVM for text in depth&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/nlp/tfidf_naive_bayes.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Subset classification&lt;/strong&gt; — Use only 4 categories (&lt;code&gt;comp.graphics&lt;/code&gt;, &lt;code&gt;rec.sport.baseball&lt;/code&gt;, &lt;code&gt;sci.space&lt;/code&gt;, &lt;code&gt;talk.politics.mideast&lt;/code&gt;). How much does accuracy improve with fewer, more distinct categories?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Feature engineering&lt;/strong&gt; — Add &lt;code&gt;min_df=5&lt;/code&gt; and &lt;code&gt;max_df=0.5&lt;/code&gt; to &lt;code&gt;CountVectorizer&lt;/code&gt; to trim rare and ubiquitous words. How does this affect accuracy and vocabulary size?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Bernoulli vs Multinomial&lt;/strong&gt; — Replace &lt;code&gt;MultinomialNB&lt;/code&gt; with &lt;code&gt;BernoulliNB&lt;/code&gt;. Does the McCallum &amp;amp; Nigam finding hold on this dataset?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Beyond bag-of-words&lt;/strong&gt; — Use &lt;code&gt;TfidfVectorizer&lt;/code&gt; with &lt;code&gt;sublinear_tf=True&lt;/code&gt; and character n-grams (&lt;code&gt;analyzer='char_wb'&lt;/code&gt;, &lt;code&gt;ngram_range=(3,5)&lt;/code&gt;). Character n-grams capture morphological patterns that word-level features miss.&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/classification-metrics-calculator" rel="noopener noreferrer"&gt;Classification Metrics Calculator&lt;/a&gt; — Compute precision, recall, F1, and other metrics from your own confusion matrix&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/bayes-theorem-calculator" rel="noopener noreferrer"&gt;Bayes' Theorem Calculator&lt;/a&gt; — Explore the Bayesian reasoning that underpins Naive Bayes classification&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;Maximum Likelihood Estimation from Scratch&lt;/a&gt; — The estimation method behind Naive Bayes parameter learning&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt; — The Bayes' theorem foundation that powers Naive Bayes classification&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/hyperparameter-optimization-grid-random-bayesian" rel="noopener noreferrer"&gt;Hyperparameter Optimization: Grid vs Random vs Bayesian&lt;/a&gt; — A deeper look at grid search and smarter alternatives&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Why is Naive Bayes called "naive"?
&lt;/h3&gt;

&lt;p&gt;The "naive" refers to the conditional independence assumption: the model assumes that each word in a document is independent of every other word, given the class. This is clearly wrong (e.g. "neural" and "network" tend to co-occur), but it works surprisingly well in practice because classification only requires getting the ranking of class probabilities right, not the exact values. Independence errors tend to cancel out across thousands of features.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the difference between TF-IDF and raw word counts?
&lt;/h3&gt;

&lt;p&gt;Raw word counts treat all words equally, so common words like "the" and "is" dominate the representation despite carrying no discriminative information. TF-IDF re-weights each word by how rare it is across the entire corpus. Words that appear in many documents get downweighted, while words distinctive to a few documents get amplified. This makes the representation much more informative for classification.&lt;/p&gt;

&lt;h3&gt;
  
  
  When should I use Naive Bayes instead of a transformer model like BERT?
&lt;/h3&gt;

&lt;p&gt;Naive Bayes with TF-IDF is an excellent choice when you need fast training (seconds, not hours), interpretability (you can inspect which words drive predictions), or when labelled data is limited. It also requires no GPU. For tasks where word order matters (sentiment analysis, entailment) or where you need state-of-the-art accuracy on competitive benchmarks, transformer models will outperform it significantly.&lt;/p&gt;

&lt;h3&gt;
  
  
  What does the smoothing parameter alpha do in MultinomialNB?
&lt;/h3&gt;

&lt;p&gt;Alpha controls Laplace smoothing, which prevents zero probabilities for words that never appeared in a particular class during training. With alpha = 1.0 (the default), the model adds a pseudocount of 1 to every word-class combination. Smaller values like 0.001 trust the training data more and smooth less aggressively. The optimal value depends on your dataset and can be found through cross-validation.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why does the model confuse related categories like PC hardware and Mac hardware?
&lt;/h3&gt;

&lt;p&gt;The bag-of-words representation captures which words appear in a document but not the subtle semantic differences between closely related topics. Categories like PC hardware and Mac hardware share a large portion of their vocabulary (words like "drive", "memory", "board", "system"). The model can only distinguish them by the few words unique to each category, which may not always be present in a given document.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can TF-IDF handle languages other than English?
&lt;/h3&gt;

&lt;p&gt;Yes. TF-IDF is language-agnostic at its core since it operates on tokens, not linguistic structures. However, you may need to adjust tokenisation for languages without clear word boundaries (e.g. Chinese or Japanese) and consider language-specific stop word lists. Stemming and lemmatisation tools are also language-dependent, so you would need appropriate resources for your target language.&lt;/p&gt;

</description>
      <category>supervisedlearning</category>
      <category>discriminative</category>
      <category>probabilistic</category>
    </item>
    <item>
      <title>AI Experts Are Dead. Long Live the AI Experts.</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Wed, 15 Apr 2026 07:51:05 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/ai-experts-are-dead-long-live-the-ai-experts-3hg2</link>
      <guid>https://dev.to/berkan_sesen/ai-experts-are-dead-long-live-the-ai-experts-3hg2</guid>
      <description>&lt;p&gt;Last month, my eight-year-old built a Flappy Bird clone from scratch. He can't really type yet. He certainly can't write Python. What he &lt;em&gt;can&lt;/em&gt; do is talk to Claude while I whisper in his ear what to say next. Within an hour, he had a working game: a bird, pipes, a score counter, gravity. He's an "AI expert" now.&lt;/p&gt;

&lt;p&gt;And honestly? So is your dentist, your cousin's teenager, and the recruiter who just messaged you on LinkedIn. The barrier to "using AI" has collapsed to the cost of typing a sentence in English. This is genuinely wonderful. Democratisation of powerful technology is how we got the internet, smartphones, and open-source software.&lt;/p&gt;

&lt;p&gt;But there's an asymmetry hiding behind this accessibility: while &lt;em&gt;using&lt;/em&gt; AI has never been cheaper, &lt;em&gt;building&lt;/em&gt; AI has never been more expensive. Training GPT-4 cost over $100 million. Llama 3 required 24,000 GPUs running for months. The companies that can afford to train foundation models from scratch fit comfortably in a single conference room. We've democratised the interface and monopolised the engine.&lt;/p&gt;

&lt;p&gt;So where does that leave the engineers, the domain experts, the people who actually know things about medicine, law, finance, or logistics? Somewhere in between. And that somewhere has a name: &lt;strong&gt;fine-tuning&lt;/strong&gt;. For a few hundred to a few thousand dollars, you can take a foundation model and make it &lt;em&gt;yours&lt;/em&gt;, trained on your data, speaking your domain's language, following your formatting rules. Not building the engine from scratch, but tuning it to your track.&lt;/p&gt;

&lt;p&gt;By the end of this post, you'll fine-tune a model on Azure OpenAI, understand the LoRA algorithm that makes it computationally feasible, and know exactly where fine-tuning sits in the hierarchy from prompt engineering to pre-training.&lt;/p&gt;

&lt;h2&gt;
  
  
  Quick Win: Fine-Tune on Azure in 20 Lines
&lt;/h2&gt;

&lt;p&gt;Let's start with the punchline. Here's everything you need to fine-tune a GPT model on Azure OpenAI. No Colab badge here (you'll need Azure credentials), but the code itself is almost disappointingly simple.&lt;/p&gt;

&lt;h3&gt;
  
  
  Prepare Your Training Data
&lt;/h3&gt;

&lt;p&gt;Azure expects JSONL format: one JSON object per line, each containing a conversation. Here's what training data looks like for a medical coding assistant:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight json"&gt;&lt;code&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="nl"&gt;"messages"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="p"&gt;[{&lt;/span&gt;&lt;span class="nl"&gt;"role"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"system"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nl"&gt;"content"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"You are a medical coding assistant. Map clinical descriptions to ICD-10 codes."&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="nl"&gt;"role"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"user"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nl"&gt;"content"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"Patient presents with acute bronchitis"&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="nl"&gt;"role"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"assistant"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nl"&gt;"content"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"J20.9 — Acute bronchitis, unspecified"&lt;/span&gt;&lt;span class="p"&gt;}]}&lt;/span&gt;&lt;span class="w"&gt;
&lt;/span&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="nl"&gt;"messages"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="p"&gt;[{&lt;/span&gt;&lt;span class="nl"&gt;"role"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"system"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nl"&gt;"content"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"You are a medical coding assistant. Map clinical descriptions to ICD-10 codes."&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="nl"&gt;"role"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"user"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nl"&gt;"content"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"Type 2 diabetes with diabetic chronic kidney disease"&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="nl"&gt;"role"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"assistant"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nl"&gt;"content"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"E11.22 — Type 2 diabetes mellitus with diabetic chronic kidney disease"&lt;/span&gt;&lt;span class="p"&gt;}]}&lt;/span&gt;&lt;span class="w"&gt;
&lt;/span&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="nl"&gt;"messages"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="p"&gt;[{&lt;/span&gt;&lt;span class="nl"&gt;"role"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"system"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nl"&gt;"content"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"You are a medical coding assistant. Map clinical descriptions to ICD-10 codes."&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="nl"&gt;"role"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"user"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nl"&gt;"content"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"Essential hypertension"&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="nl"&gt;"role"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"assistant"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="nl"&gt;"content"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;&lt;span class="s2"&gt;"I10 — Essential (primary) hypertension"&lt;/span&gt;&lt;span class="p"&gt;}]}&lt;/span&gt;&lt;span class="w"&gt;
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Each line is a complete conversation with a system prompt, user input, and the desired assistant response.&lt;/p&gt;

&lt;h3&gt;
  
  
  Upload Data and Launch Fine-Tuning
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;openai&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AzureOpenAI&lt;/span&gt;

&lt;span class="n"&gt;client&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;AzureOpenAI&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;azure_endpoint&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;https://YOUR_RESOURCE.openai.azure.com/&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;api_key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;YOUR_API_KEY&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;api_version&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;2025-03-01-preview&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Upload training file
&lt;/span&gt;&lt;span class="n"&gt;training_file&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;client&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;files&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;create&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="nb"&gt;file&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;training_data.jsonl&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;rb&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;purpose&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;fine-tune&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Create fine-tuning job
&lt;/span&gt;&lt;span class="n"&gt;job&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;client&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fine_tuning&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;jobs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;create&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;training_file&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;training_file&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;id&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;gpt-4o-mini-2024-07-18&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Job ID: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;job&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;id&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Status: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;job&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;status&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Check Status and Use Your Model
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Check progress
&lt;/span&gt;&lt;span class="n"&gt;job&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;client&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fine_tuning&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;jobs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;retrieve&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;job&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;id&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Status: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;job&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;status&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# queued → running → succeeded
&lt;/span&gt;
&lt;span class="c1"&gt;# Once succeeded, deploy and use your model
&lt;/span&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;job&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;status&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;succeeded&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;response&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;client&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;chat&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;completions&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;create&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;job&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fine_tuned_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;messages&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;
            &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;role&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;system&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;content&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;You are a medical coding assistant.&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
            &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;role&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;user&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;content&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Chronic obstructive pulmonary disease with acute exacerbation&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
        &lt;span class="p"&gt;],&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;response&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;choices&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;message&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;content&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# J44.1 — COPD with (acute) exacerbation
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
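
&lt;p&gt;Fine-tuning jobs can take anywhere from minutes to hours, so in practice you poll rather than check once. A minimal sketch, reusing the &lt;code&gt;client&lt;/code&gt; and &lt;code&gt;job&lt;/code&gt; objects from above:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import time

# Poll until the job reaches a terminal state
while job.status not in ("succeeded", "failed", "cancelled"):
    time.sleep(60)  # check once a minute
    job = client.fine_tuning.jobs.retrieve(job.id)
    print(f"Status: {job.status}")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;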



&lt;p&gt;That's it. Azure handles the infrastructure, the training loop, the checkpointing, and the deployment. You provide the data; it returns a model that speaks your domain.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;Behind those 20 lines, a lot happened. Let's unpack it.&lt;/p&gt;

&lt;h3&gt;
  
  
  The JSONL Format
&lt;/h3&gt;

&lt;p&gt;Each training example is a conversation in the chat completions format you already know. The key fields:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;role: "system"&lt;/code&gt;&lt;/strong&gt;: Sets the persona. Keep this consistent across examples.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;role: "user"&lt;/code&gt;&lt;/strong&gt;: The input your model will receive in production.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;role: "assistant"&lt;/code&gt;&lt;/strong&gt;: The &lt;em&gt;exact&lt;/em&gt; output you want the model to learn.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;You can optionally add a &lt;code&gt;"weight": 0&lt;/code&gt; field to an assistant message to exclude it from the loss computation. This is useful when you want the model to see context but only learn from specific responses.&lt;/p&gt;
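
&lt;p&gt;For example, in a multi-turn conversation you might keep an earlier assistant reply as context while training only on the final response. A sketch of what that looks like (the follow-up turn here is hypothetical):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight json"&gt;&lt;code&gt;{"messages": [{"role": "system", "content": "You are a medical coding assistant. Map clinical descriptions to ICD-10 codes."}, {"role": "user", "content": "Patient presents with acute bronchitis"}, {"role": "assistant", "content": "J20.9 — Acute bronchitis, unspecified", "weight": 0}, {"role": "user", "content": "And essential hypertension?"}, {"role": "assistant", "content": "I10 — Essential (primary) hypertension"}]}
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Here the first assistant turn is context only; the final response is the one the model learns to produce.&lt;/p&gt;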

&lt;h3&gt;
  
  
  The Training Pipeline
&lt;/h3&gt;

&lt;p&gt;When you call &lt;code&gt;client.fine_tuning.jobs.create()&lt;/code&gt;, Azure kicks off a pipeline:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Validation&lt;/strong&gt;: Checks your JSONL for formatting errors, token limits, and minimum example counts (at least 10 examples required).&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Queuing&lt;/strong&gt;: Your job waits for GPU capacity.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Training&lt;/strong&gt;: The model is fine-tuned using LoRA (more on this shortly). Azure automatically creates checkpoints.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Results&lt;/strong&gt;: A &lt;code&gt;results.csv&lt;/code&gt; file is generated with training and validation loss at each step.&lt;/li&gt;
&lt;/ol&gt;
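
&lt;p&gt;Once the job succeeds, you can pull those metrics down through the same files API. A sketch (&lt;code&gt;result_files&lt;/code&gt; is only populated after the job completes):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Download the training metrics once the job has succeeded
result_file_id = job.result_files[0]
content = client.files.content(result_file_id)
with open("results.csv", "wb") as f:
    f.write(content.read())
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;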

&lt;h3&gt;
  
  
  Hyperparameters
&lt;/h3&gt;

&lt;p&gt;You can customise the training with the &lt;code&gt;hyperparameters&lt;/code&gt; argument:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Parameter&lt;/th&gt;
&lt;th&gt;Default&lt;/th&gt;
&lt;th&gt;What It Controls&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;n_epochs&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Auto (based on dataset size)&lt;/td&gt;
&lt;td&gt;Number of passes through the training data&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;learning_rate_multiplier&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Auto&lt;/td&gt;
&lt;td&gt;Scales the base learning rate. Higher means faster but riskier.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;batch_size&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Auto&lt;/td&gt;
&lt;td&gt;Examples per gradient update. Larger is more stable but uses more memory.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;seed&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;None&lt;/td&gt;
&lt;td&gt;For reproducibility&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;job&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;client&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fine_tuning&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;jobs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;create&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;training_file&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;training_file&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;id&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;gpt-4o-mini-2024-07-18&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;hyperparameters&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;{&lt;/span&gt;
        &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;n_epochs&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;learning_rate_multiplier&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;1.8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;batch_size&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="p"&gt;},&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;For most use cases, the defaults work well. Azure auto-selects values based on your dataset size. If you want to systematically search for optimal values, &lt;a href="https://sesen.ai/blog/hyperparameter-optimization-grid-random-bayesian" rel="noopener noreferrer"&gt;hyperparameter optimisation methods&lt;/a&gt; like Bayesian optimisation can help.&lt;/p&gt;

&lt;h3&gt;
  
  
  Pricing
&lt;/h3&gt;

&lt;p&gt;Fine-tuning costs vary by model tier:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Tier&lt;/th&gt;
&lt;th&gt;Training Cost&lt;/th&gt;
&lt;th&gt;Hosting Cost&lt;/th&gt;
&lt;th&gt;Best For&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Standard&lt;/td&gt;
&lt;td&gt;Higher per-token&lt;/td&gt;
&lt;td&gt;Dedicated deployment&lt;/td&gt;
&lt;td&gt;Production workloads&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Global Standard&lt;/td&gt;
&lt;td&gt;Moderate&lt;/td&gt;
&lt;td&gt;Pay-per-use&lt;/td&gt;
&lt;td&gt;Cost-effective production&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;Training costs are measured in tokens processed. A dataset of 1,000 examples at ~200 tokens each, trained for 3 epochs, processes about 600K tokens, which typically costs a few dollars for smaller models like &lt;code&gt;gpt-4o-mini&lt;/code&gt;.&lt;/p&gt;
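
&lt;p&gt;A quick back-of-the-envelope check (the per-token rate below is an illustrative assumption; check current Azure pricing):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;examples = 1_000
tokens_per_example = 200
epochs = 3

total_tokens = examples * tokens_per_example * epochs
print(f"{total_tokens:,} tokens")  # 600,000 tokens processed during training

# Illustrative rate of $3 per million training tokens (an assumption,
# not a quoted Azure price)
print(f"~${total_tokens / 1e6 * 3:.2f}")  # ~$1.80
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;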

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What Fine-Tuning Actually Does
&lt;/h3&gt;

&lt;p&gt;A language model is a probability distribution over the next token. When you prompt GPT-4o with "The capital of France is", it assigns high probability to "Paris" and low probability to "pizza". These probabilities are determined by the model's weights, billions of numbers learned during pre-training.&lt;/p&gt;

&lt;p&gt;Fine-tuning &lt;em&gt;shifts&lt;/em&gt; these probability distributions. For our medical coding assistant, the base model might assign:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;P("J20.9") = 0.001 (it's seen ICD codes, but rarely)&lt;/li&gt;
&lt;li&gt;P("The patient has") = 0.15 (a more "natural" continuation)&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;After fine-tuning on hundreds of medical coding examples, the distribution shifts:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;P("J20.9") = 0.85&lt;/li&gt;
&lt;li&gt;P("The patient has") = 0.002&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F7i6miq5w626i6n76j02s.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F7i6miq5w626i6n76j02s.webp" alt="Before vs after fine-tuning: the probability distribution over next tokens shifts from favouring generic continuations to favouring domain-specific outputs." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The training objective is &lt;strong&gt;cross-entropy loss&lt;/strong&gt; on the assistant tokens, the same &lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;maximum likelihood estimation&lt;/a&gt; objective used in pre-training, just applied to a much smaller dataset. The model learns to maximise the probability of producing exactly the outputs in your training data.&lt;/p&gt;
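
&lt;p&gt;A minimal NumPy sketch of that objective, restricted to the assistant tokens (an illustration of the idea, not Azure's actual training loop):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

def assistant_cross_entropy(token_probs, mask):
    """Mean negative log-likelihood over assistant tokens only.

    token_probs: probability the model assigns to each reference token
    mask: 1.0 for assistant tokens to learn from, 0.0 for everything else
    """
    token_probs = np.asarray(token_probs)
    mask = np.asarray(mask)
    nll = -np.log(token_probs)  # per-token negative log-likelihood
    return (mask * nll).sum() / mask.sum()

# Probabilities assigned to the reference completion "J20.9 ..." before
# fine-tuning: the rare ICD code starts out unlikely, so the loss is high
probs = [0.05, 0.90, 0.80, 0.95]
mask = [1.0, 1.0, 1.0, 1.0]  # all four are assistant tokens
print(assistant_cross_entropy(probs, mask))  # training pushes this down
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;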

&lt;p&gt;The gradients that update the weights flow through the same &lt;a href="https://sesen.ai/blog/backpropagation-neural-nets-from-first-principles" rel="noopener noreferrer"&gt;backpropagation&lt;/a&gt; algorithm used in pre-training. The difference is scope: pre-training processes trillions of tokens across the entire internet; fine-tuning processes thousands of tokens from your specific domain.&lt;/p&gt;

&lt;h3&gt;
  
  
  LoRA: The Algorithm Under the Hood
&lt;/h3&gt;

&lt;p&gt;Here's the problem with naive fine-tuning: GPT-4o has hundreds of billions of parameters. Updating all of them requires enormous GPU memory and risks catastrophically forgetting what the model already knows. This is where &lt;strong&gt;LoRA&lt;/strong&gt; (Low-Rank Adaptation) comes in, and it's what Azure uses under the hood.&lt;/p&gt;

&lt;p&gt;The key insight from Hu et al. (2021): when you fine-tune a large language model, the weight updates have &lt;strong&gt;low intrinsic rank&lt;/strong&gt;. In plain English, the changes needed to adapt a model to a new task live in a much smaller subspace than the full parameter space.&lt;/p&gt;

&lt;p&gt;Instead of updating a weight matrix &lt;code&gt;$W \in \mathbb{R}^{d \times k}$&lt;/code&gt; directly, LoRA decomposes the update into two smaller matrices:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fl7gwuo4t37efc95quy1e.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fl7gwuo4t37efc95quy1e.webp" alt="LoRA decomposes the weight update into two small, trainable low-rank matrices B and A, leaving the original weights frozen." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DW%27%2520%253D%2520W%2520%252B%2520%255CDelta%2520W%2520%253D%2520W%2520%252B%2520BA" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DW%27%2520%253D%2520W%2520%252B%2520%255CDelta%2520W%2520%253D%2520W%2520%252B%2520BA" alt="equation" width="294" height="22"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Where:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;$W$&lt;/code&gt; is the original frozen weight matrix (&lt;code&gt;$d \times k$&lt;/code&gt;)&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$B \in \mathbb{R}^{d \times r}$&lt;/code&gt; and &lt;code&gt;$A \in \mathbb{R}^{r \times k}$&lt;/code&gt; are the trainable low-rank matrices&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$r \ll \min(d, k)$&lt;/code&gt; is the rank, typically 8, 16, or 64&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;The parameter reduction is dramatic.&lt;/strong&gt; Consider a weight matrix in a large transformer:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Original: &lt;code&gt;$d = 4096, k = 4096$&lt;/code&gt; → 16.8 million parameters&lt;/li&gt;
&lt;li&gt;LoRA with &lt;code&gt;$r = 16$&lt;/code&gt;: &lt;code&gt;$(4096 \times 16) + (16 \times 4096)$&lt;/code&gt; = 131,072 parameters&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Reduction: 99.2%&lt;/strong&gt;&lt;/li&gt;
&lt;/ul&gt;
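
&lt;p&gt;You can verify those figures in a couple of lines:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;d, k, r = 4096, 4096, 16

full_params = d * k          # 16,777,216 parameters in W
lora_params = d * r + r * k  # 131,072 parameters in B and A
print(f"reduction: {1 - lora_params / full_params:.1%}")  # reduction: 99.2%
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;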

&lt;p&gt;Across the entire model, LoRA typically trains only 0.1–1% of the original parameters. This means:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Memory&lt;/strong&gt;: You can fine-tune on a single GPU instead of a cluster.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Speed&lt;/strong&gt;: Fewer parameters to update means faster training.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Storage&lt;/strong&gt;: Each fine-tuned version is just the small &lt;code&gt;$B$&lt;/code&gt; and &lt;code&gt;$A$&lt;/code&gt; matrices, megabytes instead of gigabytes.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;No extra inference latency&lt;/strong&gt;: At deployment, &lt;code&gt;$BA$&lt;/code&gt; is merged back into &lt;code&gt;$W$&lt;/code&gt;. The final model has exactly the same architecture and speed as the original.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The initialisation matters too: &lt;code&gt;$B$&lt;/code&gt; is initialised to zero and &lt;code&gt;$A$&lt;/code&gt; to random Gaussian, so &lt;code&gt;$\Delta W = BA = 0$&lt;/code&gt; at the start. Training begins from the exact pre-trained model, with no disruption.&lt;/p&gt;
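
&lt;p&gt;To make the mechanics concrete, here is a toy LoRA linear layer in PyTorch. This is a sketch of the paper's idea, not what Azure runs internally:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen linear layer plus a trainable low-rank update BA."""

    def __init__(self, base: nn.Linear, r: int = 16, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze W (and bias)
        d, k = base.out_features, base.in_features
        self.A = nn.Parameter(torch.randn(r, k) * 0.01)  # random Gaussian init
        self.B = nn.Parameter(torch.zeros(d, r))         # zeros, so BA = 0 at start
        self.scale = alpha / r

    def forward(self, x):
        # W x plus the scaled low-rank correction B(Ax)
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

layer = LoRALinear(nn.Linear(4096, 4096), r=16)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(f"{trainable:,} trainable parameters")  # 131,072 (just A and B)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;At deployment you would fold &lt;code&gt;scale * (B @ A)&lt;/code&gt; into the frozen weight once, which is exactly why merged LoRA adds no inference latency.&lt;/p&gt;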

&lt;h3&gt;
  
  
  The PEFT Family
&lt;/h3&gt;

&lt;p&gt;LoRA belongs to a broader family called &lt;strong&gt;Parameter-Efficient Fine-Tuning (PEFT)&lt;/strong&gt; methods. Here's how they compare:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Method&lt;/th&gt;
&lt;th&gt;Trainable Params&lt;/th&gt;
&lt;th&gt;Memory&lt;/th&gt;
&lt;th&gt;Quality&lt;/th&gt;
&lt;th&gt;Inference Overhead&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Full fine-tuning&lt;/td&gt;
&lt;td&gt;100%&lt;/td&gt;
&lt;td&gt;Very high&lt;/td&gt;
&lt;td&gt;Best&lt;/td&gt;
&lt;td&gt;None&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;LoRA&lt;/td&gt;
&lt;td&gt;0.1–1%&lt;/td&gt;
&lt;td&gt;Low&lt;/td&gt;
&lt;td&gt;Near-full&lt;/td&gt;
&lt;td&gt;None (merged)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;QLoRA&lt;/td&gt;
&lt;td&gt;0.1–1%&lt;/td&gt;
&lt;td&gt;Very low&lt;/td&gt;
&lt;td&gt;Good&lt;/td&gt;
&lt;td&gt;Slight (quantisation)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Prefix tuning&lt;/td&gt;
&lt;td&gt;Under 0.1%&lt;/td&gt;
&lt;td&gt;Very low&lt;/td&gt;
&lt;td&gt;Moderate&lt;/td&gt;
&lt;td&gt;Slight (extra tokens)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Adapters&lt;/td&gt;
&lt;td&gt;1–5%&lt;/td&gt;
&lt;td&gt;Low&lt;/td&gt;
&lt;td&gt;Good&lt;/td&gt;
&lt;td&gt;Slight (extra layers)&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;LoRA&lt;/strong&gt; is the most popular because it hits the sweet spot: near-full fine-tuning quality with no inference overhead. &lt;strong&gt;QLoRA&lt;/strong&gt; adds 4-bit quantisation of the base model, reducing memory further. You can fine-tune a 65B parameter model on a single 48GB GPU. &lt;strong&gt;Prefix tuning&lt;/strong&gt; prepends learnable "virtual tokens" to the input, but quality degrades for complex tasks. &lt;strong&gt;Adapters&lt;/strong&gt; insert small trainable layers between existing transformer blocks.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Moat Spectrum
&lt;/h3&gt;

&lt;p&gt;Not all AI customisation is equal. Here's the full spectrum, from least to most defensible:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ff36av4twfhlgwuvzqujs.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ff36av4twfhlgwuvzqujs.webp" alt="The moat spectrum: from prompt engineering (free, no moat) through RAG and fine-tuning to pre-training (very high cost, strong moat). Fine-tuning is the sweet spot." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Approach&lt;/th&gt;
&lt;th&gt;Cost&lt;/th&gt;
&lt;th&gt;Setup Time&lt;/th&gt;
&lt;th&gt;Moat Strength&lt;/th&gt;
&lt;th&gt;When to Use&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Prompt engineering&lt;/td&gt;
&lt;td&gt;Free&lt;/td&gt;
&lt;td&gt;Minutes&lt;/td&gt;
&lt;td&gt;None&lt;/td&gt;
&lt;td&gt;Prototyping, one-off tasks&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;RAG (Retrieval-Augmented Generation)&lt;/td&gt;
&lt;td&gt;$10s–100s/mo&lt;/td&gt;
&lt;td&gt;Days&lt;/td&gt;
&lt;td&gt;Weak (data can be copied)&lt;/td&gt;
&lt;td&gt;Need current information, citations&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Fine-tuning&lt;/td&gt;
&lt;td&gt;$100s–1,000s&lt;/td&gt;
&lt;td&gt;Days–weeks&lt;/td&gt;
&lt;td&gt;Moderate (behaviour is learned)&lt;/td&gt;
&lt;td&gt;Consistent formatting, domain tone, cost at scale&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Pre-training&lt;/td&gt;
&lt;td&gt;$10Ms–100Ms+&lt;/td&gt;
&lt;td&gt;Months&lt;/td&gt;
&lt;td&gt;Strong (architecture + data)&lt;/td&gt;
&lt;td&gt;You're OpenAI, Google, or Meta&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;Prompt engineering&lt;/strong&gt; is where most people stop. It works surprisingly well but offers zero competitive moat; anyone can copy your prompt. &lt;strong&gt;RAG&lt;/strong&gt; adds your own data at inference time, which is powerful for knowledge-intensive tasks but the behaviour is still the base model's. &lt;strong&gt;Fine-tuning&lt;/strong&gt; embeds behaviour into the weights. The model doesn't need to be told &lt;em&gt;how&lt;/em&gt; to respond; it just &lt;em&gt;does&lt;/em&gt;. &lt;strong&gt;Pre-training&lt;/strong&gt; is building the engine from scratch, and unless you have a few hundred million dollars and a research lab, it's not your game.&lt;/p&gt;

&lt;h3&gt;
  
  
  When Fine-Tuning Beats Prompting
&lt;/h3&gt;

&lt;p&gt;Fine-tuning wins over prompting when you need:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Consistent output formatting.&lt;/strong&gt; JSON schemas, code conventions, structured reports. A fine-tuned model follows the format without lengthy system prompts.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Domain-specific behaviour.&lt;/strong&gt; Medical coding, legal analysis, financial compliance. The model internalises domain norms.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Tone and style.&lt;/strong&gt; Brand voice, technical writing style, conversational patterns.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Cost at scale.&lt;/strong&gt; A fine-tuned model with a short prompt is cheaper per request than a base model with a 2,000-token system prompt.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Latency.&lt;/strong&gt; Shorter prompts mean fewer input tokens to process.&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  When Fine-Tuning Loses to RAG
&lt;/h3&gt;

&lt;p&gt;Fine-tuning embeds knowledge into weights, but weights are frozen after training. If the information changes frequently (stock prices, medical guidelines, product catalogues), RAG is the better choice. RAG retrieves current documents at inference time, so the model always has access to the latest information.&lt;/p&gt;

&lt;p&gt;The best systems often combine both: fine-tune for &lt;em&gt;behaviour&lt;/em&gt; (how to respond), RAG for &lt;em&gt;knowledge&lt;/em&gt; (what to respond with).&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;h3&gt;
  
  
  LoRA: Low-Rank Adaptation of Large Language Models
&lt;/h3&gt;

&lt;p&gt;LoRA was introduced by &lt;a href="https://arxiv.org/abs/2106.09685" rel="noopener noreferrer"&gt;Hu et al. (2021)&lt;/a&gt; at Microsoft Research. The paper's central hypothesis is elegant:&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"We hypothesize that the change in weights during model adaptation also has a low 'intrinsic rank,' which leads us to propose Low-Rank Adaptation (LoRA)."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;The authors demonstrated that for GPT-3 175B, LoRA with rank 4 matched or exceeded full fine-tuning performance on multiple benchmarks while training only 0.01% of the parameters. They tested on natural language understanding (GLUE), natural language generation (E2E NLG), and instruction following. LoRA matched full fine-tuning across the board.&lt;/p&gt;

&lt;p&gt;A key practical insight from the paper: LoRA is most effective when applied to the attention weight matrices (&lt;code&gt;$W_Q$&lt;/code&gt; and &lt;code&gt;$W_V$&lt;/code&gt;), rather than the feed-forward layers. This is because attention matrices control the model's "routing" of information (which tokens attend to which), and task-specific behaviour is largely about changing these routing patterns.&lt;/p&gt;

&lt;h3&gt;
  
  
  ULMFiT: The Transfer Learning Paradigm
&lt;/h3&gt;

&lt;p&gt;Before LoRA, there was &lt;strong&gt;ULMFiT&lt;/strong&gt;. &lt;a href="https://arxiv.org/abs/1801.06146" rel="noopener noreferrer"&gt;Howard &amp;amp; Ruder (2018)&lt;/a&gt; established the now-standard paradigm: pre-train on a large corpus, then fine-tune on your task. Their key innovations were discriminative fine-tuning (different learning rates per layer) and gradual unfreezing; both are conceptual ancestors of LoRA's approach.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Broader Lineage
&lt;/h3&gt;

&lt;p&gt;The idea that pre-trained representations can be adapted to new tasks has a long history:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;ImageNet transfer learning (2012–2014).&lt;/strong&gt; Training on ImageNet, fine-tuning on medical images. Computer vision proved the concept.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;ULMFiT (2018).&lt;/strong&gt; Brought transfer learning to NLP. Demonstrated that language model pre-training produces universal features.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;BERT (2018) and GPT (2018).&lt;/strong&gt; Scaled the paradigm. Pre-train once, fine-tune for everything.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;LoRA (2021).&lt;/strong&gt; Made fine-tuning efficient enough for massive models. You don't need to update every parameter.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Each step reduced the barrier. LoRA's contribution is making fine-tuning feasible for models so large that full fine-tuning would require a cluster.&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://arxiv.org/abs/2106.09685" rel="noopener noreferrer"&gt;Hu et al. (2021), LoRA: Low-Rank Adaptation of Large Language Models&lt;/a&gt;. The original paper. Read Section 4 for the core method.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://arxiv.org/abs/2305.14314" rel="noopener noreferrer"&gt;Dettmers et al. (2023), QLoRA: Efficient Finetuning of Quantized LLMs&lt;/a&gt;. Combines 4-bit quantisation with LoRA.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://arxiv.org/abs/1801.06146" rel="noopener noreferrer"&gt;Howard &amp;amp; Ruder (2018), Universal Language Model Fine-tuning for Text Classification&lt;/a&gt;. The paper that established fine-tuning as the NLP paradigm.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/fine-tuning" rel="noopener noreferrer"&gt;Azure OpenAI Fine-Tuning Documentation&lt;/a&gt;. Official Azure docs for the code in this post.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;The format experiment.&lt;/strong&gt; Take a task where you want structured output (e.g., JSON with specific fields). Compare: (a) a detailed system prompt describing the format, vs (b) a fine-tuned model trained on 50 examples of the correct format. Measure how often each produces valid output.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Data quality vs quantity.&lt;/strong&gt; Create two training sets for the same task: 50 carefully curated, high-quality examples vs 500 noisy, auto-generated examples. Fine-tune on each. Quality almost always wins. This is the moat.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The moat test.&lt;/strong&gt; Fine-tune a model on a specific domain task. Then try to replicate the same behaviour using only prompt engineering. How close can you get? Where does prompting fall short?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;LoRA from scratch.&lt;/strong&gt; Implement a toy LoRA layer in PyTorch. Freeze a pre-trained GPT-2 model, add &lt;code&gt;$BA$&lt;/code&gt; matrices to the attention layers, and fine-tune on a small text classification task. Compare the parameter count to full fine-tuning.&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/tools" rel="noopener noreferrer"&gt;Explore Our Free Tools&lt;/a&gt; — Hands-on calculators and visualisers for statistics, machine learning, and quantitative finance&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;Maximum Likelihood Estimation from Scratch&lt;/a&gt;. Fine-tuning's loss function (cross-entropy) is maximum likelihood estimation. Understanding MLE gives you intuition for what the training loop is optimising.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/backpropagation-neural-nets-from-first-principles" rel="noopener noreferrer"&gt;Backpropagation and Neural Nets from First Principles&lt;/a&gt;. The gradient computation that makes both pre-training and fine-tuning work. LoRA reduces the number of parameters, but the gradients still flow through the same algorithm.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/hyperparameter-optimization-grid-random-bayesian" rel="noopener noreferrer"&gt;Hyperparameter Optimisation: Grid, Random, and Bayesian&lt;/a&gt;. Fine-tuning has its own hyperparameters (learning rate, epochs, LoRA rank). This post covers systematic approaches to tuning them.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is the difference between fine-tuning and prompt engineering?
&lt;/h3&gt;

&lt;p&gt;Prompt engineering gives instructions to a base model at inference time, while fine-tuning embeds behaviour directly into the model's weights through additional training. Fine-tuning produces more consistent outputs without lengthy system prompts and can reduce per-request costs at scale. However, prompt engineering requires zero setup and is the right starting point for prototyping.&lt;/p&gt;

&lt;h3&gt;
  
  
  How much training data do I need for fine-tuning?
&lt;/h3&gt;

&lt;p&gt;Azure OpenAI requires a minimum of 10 examples, but practical results typically need 50 to 500 high-quality examples depending on task complexity. Data quality matters far more than quantity: 50 carefully curated examples often outperform 500 noisy ones. Start small, evaluate, and add more data only if the model underperforms.&lt;/p&gt;

&lt;h3&gt;
  
  
  Does fine-tuning change the entire model?
&lt;/h3&gt;

&lt;p&gt;No. Modern fine-tuning uses LoRA (Low-Rank Adaptation), which freezes the original model weights and trains only small low-rank matrices added to the attention layers. This typically updates only 0.1 to 1% of the original parameters, making fine-tuning feasible on modest hardware while preserving the base model's general capabilities.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can I fine-tune open-source models instead of using Azure?
&lt;/h3&gt;

&lt;p&gt;Yes. Open-source models like Llama and Mistral can be fine-tuned locally using libraries such as Hugging Face PEFT and QLoRA. The LoRA algorithm is the same regardless of platform. The trade-off is that you manage the infrastructure yourself, but you gain full control over the model and avoid ongoing API costs.&lt;/p&gt;
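
&lt;p&gt;As a taste of the local workflow, here is a minimal sketch with Hugging Face PEFT (GPT-2 and the &lt;code&gt;c_attn&lt;/code&gt; target module are illustrative choices, not a recommendation):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("gpt2")
config = LoraConfig(r=16, lora_alpha=16, target_modules=["c_attn"])
model = get_peft_model(base, config)
model.print_trainable_parameters()  # a fraction of a percent is trainable
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;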

&lt;h3&gt;
  
  
  When should I use RAG instead of fine-tuning?
&lt;/h3&gt;

&lt;p&gt;Use RAG (Retrieval-Augmented Generation) when the knowledge your model needs changes frequently, such as product catalogues, medical guidelines, or pricing data. Fine-tuning embeds knowledge into frozen weights, so it cannot adapt to new information without retraining. The best systems often combine both: fine-tune for consistent behaviour and formatting, then use RAG to inject up-to-date knowledge at inference time.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is QLoRA and how does it differ from LoRA?
&lt;/h3&gt;

&lt;p&gt;QLoRA combines LoRA with 4-bit quantisation of the base model, reducing memory requirements even further. With QLoRA, you can fine-tune a 65-billion parameter model on a single 48GB GPU. The trade-off is a slight quality reduction from quantisation and marginally higher inference latency compared to standard LoRA.&lt;/p&gt;

</description>
      <category>llm</category>
      <category>deeplearning</category>
      <category>optimisation</category>
    </item>
    <item>
      <title>Gaussian Process Regression: The Bayesian Approach to Curve Fitting</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Mon, 13 Apr 2026 08:33:12 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/gaussian-process-regression-the-bayesian-approach-to-curve-fitting-k7d</link>
      <guid>https://dev.to/berkan_sesen/gaussian-process-regression-the-bayesian-approach-to-curve-fitting-k7d</guid>
      <description>&lt;p&gt;You've trained a machine learning model and want to tune its hyperparameters. Each evaluation takes hours. You've tested 6 configurations so far. Where should you try next?&lt;/p&gt;

&lt;p&gt;If you read our &lt;a href="https://sesen.ai/blog/hyperparameter-optimization-grid-random-bayesian" rel="noopener noreferrer"&gt;hyperparameter optimisation post&lt;/a&gt;, you saw Bayesian optimisation solve exactly this problem. The secret weapon behind it is a &lt;strong&gt;Gaussian process&lt;/strong&gt; (GP) — a model that predicts not just a value, but &lt;em&gt;how uncertain it is about that value&lt;/em&gt;. Near your tested configurations, the GP is confident. Far away, it honestly admits "I don't know."&lt;/p&gt;

&lt;p&gt;This is regression with built-in uncertainty quantification. Unlike fitting a line or a polynomial, a GP doesn't commit to a fixed functional form. Instead, it defines a &lt;em&gt;distribution over functions&lt;/em&gt; and lets the data narrow it down. The result is a smooth prediction with a confidence band that widens where data is sparse and tightens where it's dense.&lt;/p&gt;

&lt;p&gt;By the end of this post, you'll implement GP regression from scratch with NumPy, understand how the kernel function encodes your assumptions about smoothness, and see exactly why GPs power Bayesian optimisation.&lt;/p&gt;

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;Click the badge to open the interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/bayesian/gaussian_process_regression.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Watch how the GP posterior sharpens as we feed it observations one by one — the shaded region represents 95% confidence, and it collapses around each data point:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F43gfmebdbkn06dvtzxjk.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F43gfmebdbkn06dvtzxjk.gif" alt="GP posterior building up as observations are added one by one. The uncertainty band starts wide (prior) and collapses around each new data point." width="1000" height="500"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Here's the complete implementation. We'll use 6 noisy observations from &lt;a href="https://www.robots.ox.ac.uk/~mebden/reports/GPtutorial.pdf" rel="noopener noreferrer"&gt;Ebden's GP tutorial (2008)&lt;/a&gt; and predict across a dense grid:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;numpy.linalg&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;inv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;det&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;

&lt;span class="c1"&gt;# --- The Kernel ---
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;rbf_kernel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Squared exponential (RBF) kernel.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;sigma_f&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;x2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;kernel_matrix&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Compute kernel matrix between two sets of points.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([[&lt;/span&gt;&lt;span class="nf"&gt;rbf_kernel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;X2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;X1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

&lt;span class="c1"&gt;# --- GP Prediction (Ebden Eq. 8 &amp;amp; 9) ---
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;gp_predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_n&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Predict mean and variance at test points.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;K&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;kernel_matrix&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;sigma_n&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;K_s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;kernel_matrix&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;K_ss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;kernel_matrix&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;sigma_n&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="n"&gt;K_inv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;inv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;K_s&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;K_inv&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;
    &lt;span class="n"&gt;cov&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;K_ss&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;K_s&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;K_inv&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;K_s&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cov&lt;/span&gt;

&lt;span class="c1"&gt;# --- Data from Ebden (2008) tutorial ---
&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.00&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.75&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.40&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.25&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.00&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;y_train&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.55&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;3.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.6&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

&lt;span class="c1"&gt;# Hyperparameters (optimised values from the tutorial: sigma_f=1.27, l=1.0)
&lt;/span&gt;&lt;span class="n"&gt;sigma_n&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.3&lt;/span&gt;   &lt;span class="c1"&gt;# observation noise (known from error bars)
&lt;/span&gt;&lt;span class="n"&gt;sigma_f&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1.27&lt;/span&gt;   &lt;span class="c1"&gt;# signal standard deviation (optimised)
&lt;/span&gt;&lt;span class="n"&gt;l&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;          &lt;span class="c1"&gt;# length-scale (optimised)
&lt;/span&gt;
&lt;span class="c1"&gt;# Predict on a dense grid
&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cov&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;gp_predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_n&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;diag&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cov&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="c1"&gt;# Plot
&lt;/span&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fill_between&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;1.96&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;std&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mf"&gt;1.96&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;std&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tab:blue&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;95% confidence&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tab:blue&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;GP mean&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;errorbar&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;yerr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma_n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;fmt&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;ro&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;capsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;markersize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Training data&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;x&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsp5ex4uz85kj74vwhjng.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsp5ex4uz85kj74vwhjng.webp" alt="GP posterior: smooth mean prediction (blue line) with 95% confidence band (shaded) passing through 6 noisy observations (red dots with error bars)." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The blue line is our best estimate. The shaded band is the 95% confidence interval — wide where we have no data, narrow near the observations. At &lt;code&gt;$x_* = 0.2$&lt;/code&gt;, the GP predicts &lt;code&gt;$\bar{y}_* \approx 0.98$&lt;/code&gt; with variance 0.21, closely matching Ebden's worked example (&lt;code&gt;$\bar{y}_* = 0.95$&lt;/code&gt;, &lt;code&gt;$\text{var} = 0.21$&lt;/code&gt;).&lt;/p&gt;
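
&lt;p&gt;You can reproduce that spot check directly with the code above:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;x_star = np.array([0.2])
mu_star, cov_star = gp_predict(X_train, y_train, x_star, sigma_f, l, sigma_n)
print(mu_star[0], cov_star[0, 0])  # roughly 0.98 and 0.21, as quoted above
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;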

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;Three ingredients make GP regression work: a &lt;strong&gt;kernel function&lt;/strong&gt; that encodes our smoothness assumptions, &lt;strong&gt;covariance matrices&lt;/strong&gt; that capture relationships between all points, and &lt;strong&gt;conditioning&lt;/strong&gt; that turns the prior into a posterior.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Kernel: Encoding Smoothness
&lt;/h3&gt;

&lt;p&gt;The squared exponential (RBF) kernel measures similarity between inputs:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dk%28x%252C%2520x%27%29%2520%253D%2520%255Csigma_f%255E2%2520%255Cexp%255C%21%255Cleft%28-%255Cfrac%257B%28x%2520-%2520x%27%29%255E2%257D%257B2%255Cell%255E2%257D%255Cright%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dk%28x%252C%2520x%27%29%2520%253D%2520%255Csigma_f%255E2%2520%255Cexp%255C%21%255Cleft%28-%255Cfrac%257B%28x%2520-%2520x%27%29%255E2%257D%257B2%255Cell%255E2%257D%255Cright%29" alt="equation" width="324" height="61"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Think of it as answering: "If I know the function value at &lt;code&gt;$x$&lt;/code&gt;, how much does that tell me about &lt;code&gt;$x'$&lt;/code&gt;?"&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Close together&lt;/strong&gt; (&lt;code&gt;$|x - x'| \ll \ell$&lt;/code&gt;): &lt;code&gt;$k \approx \sigma_f^2$&lt;/code&gt; — strong correlation, they "see" each other&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Far apart&lt;/strong&gt; (&lt;code&gt;$|x - x'| \gg \ell$&lt;/code&gt;): &lt;code&gt;$k \approx 0$&lt;/code&gt; — independent, no information shared&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The two hyperparameters control different things:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;$\sigma_f$&lt;/code&gt; (signal std) — the typical amplitude of the function. Higher means larger vertical swings.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$\ell$&lt;/code&gt; (length-scale) — how far you need to move in &lt;code&gt;$x$&lt;/code&gt; before the function value changes significantly. Short &lt;code&gt;$\ell$&lt;/code&gt; gives wiggly functions; long &lt;code&gt;$\ell$&lt;/code&gt; gives smooth ones.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;There's also &lt;code&gt;$\sigma_n$&lt;/code&gt; (noise std), capturing measurement noise: each observation &lt;code&gt;$y$&lt;/code&gt; relates to the true function via &lt;code&gt;$y = f(x) + \mathcal{N}(0, \sigma_n^2)$&lt;/code&gt;.&lt;/p&gt;
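
&lt;p&gt;In code, the kernel is a one-liner. Here is a minimal sketch of what &lt;code&gt;rbf_kernel&lt;/code&gt; computes (the notebook's version may differ cosmetically):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

def rbf_kernel(x1, x2, sigma_f, l):
    """Squared exponential kernel k(x, x')."""
    return sigma_f**2 * np.exp(-(x1 - x2)**2 / (2 * l**2))

# Nearby inputs are strongly correlated; distant inputs are nearly independent
print(rbf_kernel(0.0, 0.1, sigma_f=1.27, l=1.0))  # ~1.60, close to sigma_f**2
print(rbf_kernel(0.0, 5.0, sigma_f=1.27, l=1.0))  # ~6e-06, close to 0
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;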

&lt;p&gt;The effect of the length-scale is dramatic:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F1e0zf0nfkbu3zalny33f.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F1e0zf0nfkbu3zalny33f.webp" alt="Three GP regression fits with different length-scales: short (l=0.3, wiggly), medium (l=1.0, smooth), and long (l=3.0, over-smoothed)." width="800" height="307"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;With &lt;code&gt;$\ell = 0.3$&lt;/code&gt;, the GP tries to wiggle through every point — it over-fits. With &lt;code&gt;$\ell = 3.0$&lt;/code&gt;, it's so smooth it can't follow the data's trend — it under-fits. The optimised &lt;code&gt;$\ell = 1.0$&lt;/code&gt; strikes the right balance.&lt;/p&gt;

&lt;h3&gt;
  
  
  Building the Covariance Matrices
&lt;/h3&gt;

&lt;p&gt;To make predictions, we compute three kernel matrices:&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;&lt;code&gt;$K$&lt;/code&gt;&lt;/strong&gt; — between all training points (Ebden Eq. 4):&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DK_%257Bij%257D%2520%253D%2520k%28x_i%252C%2520x_j%29%2520%252B%2520%255Csigma_n%255E2%2520%255Cdelta_%257Bij%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DK_%257Bij%257D%2520%253D%2520k%28x_i%252C%2520x_j%29%2520%252B%2520%255Csigma_n%255E2%2520%255Cdelta_%257Bij%257D" alt="equation" width="235" height="29"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The &lt;code&gt;$\sigma_n^2$&lt;/code&gt; on the diagonal accounts for observation noise. Off-diagonal entries capture how correlated two training points are.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;&lt;code&gt;$K_*$&lt;/code&gt;&lt;/strong&gt; — between test and training points (Ebden Eq. 5):&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DK_%2A%2520%253D%2520%255Cbegin%257Bbmatrix%257D%2520k%28x_%2A%252C%2520x_1%29%2520%2526%2520k%28x_%2A%252C%2520x_2%29%2520%2526%2520%255Ccdots%2520%2526%2520k%28x_%2A%252C%2520x_n%29%2520%255Cend%257Bbmatrix%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DK_%2A%2520%253D%2520%255Cbegin%257Bbmatrix%257D%2520k%28x_%2A%252C%2520x_1%29%2520%2526%2520k%28x_%2A%252C%2520x_2%29%2520%2526%2520%255Ccdots%2520%2526%2520k%28x_%2A%252C%2520x_n%29%2520%255Cend%257Bbmatrix%257D" alt="equation" width="448" height="30"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;&lt;code&gt;$K_{**}$&lt;/code&gt;&lt;/strong&gt; — between test points:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DK_%257B%2A%2A%257D%2520%253D%2520k%28x_%2A%252C%2520x_%2A%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DK_%257B%2A%2A%257D%2520%253D%2520k%28x_%2A%252C%2520x_%2A%29" alt="equation" width="161" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;In our code, &lt;code&gt;kernel_matrix&lt;/code&gt; computes these. The &lt;code&gt;+ sigma_n**2 * np.eye(n)&lt;/code&gt; adds noise to the diagonal.&lt;/p&gt;
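
&lt;p&gt;A minimal sketch of that construction, vectorised with NumPy broadcasting (the notebook's &lt;code&gt;kernel_matrix&lt;/code&gt; may differ in details):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

def kernel_matrix(X1, X2, sigma_f, l):
    """Pairwise RBF kernel matrix between two 1-D input arrays."""
    diff = X1[:, None] - X2[None, :]  # len(X1) x len(X2) difference matrix
    return sigma_f**2 * np.exp(-diff**2 / (2 * l**2))

# Training covariance K (Ebden Eq. 4): noise is added on the diagonal
# K = kernel_matrix(X_train, X_train, sigma_f, l) + sigma_n**2 * np.eye(len(X_train))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;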

&lt;h3&gt;
  
  
  The Prediction Equations
&lt;/h3&gt;

&lt;p&gt;The GP assumption is that training outputs &lt;code&gt;$\mathbf{y}$&lt;/code&gt; and test outputs &lt;code&gt;$y_*$&lt;/code&gt; are jointly Gaussian:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cbegin%257Bbmatrix%257D%2520%255Cmathbf%257By%257D%2520%255C%255C%2520y_%2A%2520%255Cend%257Bbmatrix%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%255C%21%255Cleft%28%255Cmathbf%257B0%257D%252C%2520%255Cbegin%257Bbmatrix%257D%2520K%2520%2526%2520K_%2A%255E%255Ctop%2520%255C%255C%2520K_%2A%2520%2526%2520K_%257B%2A%2A%257D%2520%255Cend%257Bbmatrix%257D%255Cright%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cbegin%257Bbmatrix%257D%2520%255Cmathbf%257By%257D%2520%255C%255C%2520y_%2A%2520%255Cend%257Bbmatrix%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%255C%21%255Cleft%28%255Cmathbf%257B0%257D%252C%2520%255Cbegin%257Bbmatrix%257D%2520K%2520%2526%2520K_%2A%255E%255Ctop%2520%255C%255C%2520K_%2A%2520%2526%2520K_%257B%2A%2A%257D%2520%255Cend%257Bbmatrix%257D%255Cright%29" alt="equation" width="273" height="61"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Conditioning on the observed &lt;code&gt;$\mathbf{y}$&lt;/code&gt; gives the posterior — also Gaussian (this is the beauty of Gaussians: conditioning preserves Gaussianity):&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cbar%257By%257D_%2A%2520%253D%2520K_%2A%2520K%255E%257B-1%257D%2520%255Cmathbf%257By%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cbar%257By%257D_%2A%2520%253D%2520K_%2A%2520K%255E%257B-1%257D%2520%255Cmathbf%257By%257D" alt="equation" width="147" height="26"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257Bvar%257D%28y_%2A%29%2520%253D%2520K_%257B%2A%2A%257D%2520-%2520K_%2A%2520K%255E%257B-1%257D%2520K_%2A%255E%255Ctop" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257Bvar%257D%28y_%2A%29%2520%253D%2520K_%257B%2A%2A%257D%2520-%2520K_%2A%2520K%255E%257B-1%257D%2520K_%2A%255E%255Ctop" alt="equation" width="289" height="28"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The mean &lt;code&gt;$\bar{y}_*$&lt;/code&gt; is a &lt;strong&gt;weighted combination&lt;/strong&gt; of the training outputs, where the weights come from how correlated &lt;code&gt;$x_*$&lt;/code&gt; is with each training point. The variance shrinks where &lt;code&gt;$K_*$&lt;/code&gt; has large entries — near training data — and expands where &lt;code&gt;$x_*$&lt;/code&gt; is far from any observation.&lt;/p&gt;
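
&lt;p&gt;Putting the pieces together, here is a condensed sketch of &lt;code&gt;gp_predict&lt;/code&gt; consistent with these equations (this sketch uses &lt;code&gt;np.linalg.solve&lt;/code&gt; rather than an explicit matrix inverse; both compute the same quantities, but solving is numerically safer):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

def gp_predict(X_train, y_train, X_test, sigma_f, l, sigma_n):
    """GP posterior mean and covariance at X_test (Ebden Eqs. 8-9)."""
    n = len(X_train)
    K = kernel_matrix(X_train, X_train, sigma_f, l) + sigma_n**2 * np.eye(n)
    K_s = kernel_matrix(X_test, X_train, sigma_f, l)
    K_ss = kernel_matrix(X_test, X_test, sigma_f, l)

    mu = K_s @ np.linalg.solve(K, y_train)        # predictive mean
    cov = K_ss - K_s @ np.linalg.solve(K, K_s.T)  # predictive covariance
    return mu, cov
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;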

&lt;p&gt;This is &lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;Bayesian inference&lt;/a&gt; in action: the prior (encoded by the kernel) gets updated by the data to produce a posterior that's both a prediction and an honest assessment of uncertainty.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  GP Prior: What the Kernel Believes Before Seeing Data
&lt;/h3&gt;

&lt;p&gt;Before any observations, the GP prior defines a distribution over functions. We can sample from it by drawing from &lt;code&gt;$\mathcal{N}(\mathbf{0}, K)$&lt;/code&gt; where &lt;code&gt;$K$&lt;/code&gt; is the kernel matrix evaluated on a grid:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;X_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;K_prior&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;kernel_matrix&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_grid&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_grid&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_f&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;K_prior&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mf"&gt;1e-8&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_grid&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;  &lt;span class="c1"&gt;# numerical stability
&lt;/span&gt;
&lt;span class="c1"&gt;# Draw 3 random functions from the prior
&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;multivariate_normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_grid&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt; &lt;span class="n"&gt;K_prior&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_grid&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;samples&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Sample &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fill_between&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_grid&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.96&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mf"&gt;1.27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.96&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mf"&gt;1.27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;grey&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;95% prior band&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;x&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;f(x)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fhgdob6wtu3jcgw8g9enu.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fhgdob6wtu3jcgw8g9enu.webp" alt="Three random functions sampled from the GP prior. All are smooth (determined by the length-scale) with amplitudes controlled by sigma_f." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Every sample is smooth — the RBF kernel enforces this. The length-scale &lt;code&gt;$\ell = 1.0$&lt;/code&gt; means the functions vary on a scale of roughly 1 unit in &lt;code&gt;$x$&lt;/code&gt;. Data will prune this infinite family down to the functions consistent with our observations.&lt;/p&gt;
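
&lt;p&gt;Conditioning works the same way after the fact: reusing the &lt;code&gt;mu&lt;/code&gt; and &lt;code&gt;cov&lt;/code&gt; computed earlier, we can draw samples from the posterior instead. A quick sketch (the small diagonal jitter keeps the covariance numerically positive-definite):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Sample functions from the posterior rather than the prior
post_cov = cov + 1e-8 * np.eye(len(X_test))  # jitter for numerical stability
post_samples = np.random.multivariate_normal(mu, post_cov, size=5)

fig, ax = plt.subplots(figsize=(10, 5))
for i in range(5):
    ax.plot(X_test, post_samples[i], alpha=0.7)
ax.errorbar(X_train, y_train, yerr=sigma_n, fmt='ro', capsize=4)
plt.show()
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Every posterior sample now passes close to the observations while still fanning out where there is no data.&lt;/p&gt;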

&lt;h3&gt;
  
  
  Hyperparameter Optimisation: Marginal Log-Likelihood
&lt;/h3&gt;

&lt;p&gt;How did we arrive at &lt;code&gt;$\sigma_f = 1.27$&lt;/code&gt; and &lt;code&gt;$\ell = 1.0$&lt;/code&gt;? The original code optimises the &lt;strong&gt;marginal log-likelihood&lt;/strong&gt; (Ebden Eq. 10):&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Clog%2520p%28%255Cmathbf%257By%257D%2520%255Cmid%2520%255Cmathbf%257Bx%257D%252C%2520%255Cboldsymbol%257B%255Ctheta%257D%29%2520%253D%2520-%255Cfrac%257B1%257D%257B2%257D%255Cmathbf%257By%257D%255E%255Ctop%2520K%255E%257B-1%257D%255Cmathbf%257By%257D%2520-%2520%255Cfrac%257B1%257D%257B2%257D%255Clog%257CK%257C%2520-%2520%255Cfrac%257Bn%257D%257B2%257D%255Clog%25202%255Cpi" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Clog%2520p%28%255Cmathbf%257By%257D%2520%255Cmid%2520%255Cmathbf%257Bx%257D%252C%2520%255Cboldsymbol%257B%255Ctheta%257D%29%2520%253D%2520-%255Cfrac%257B1%257D%257B2%257D%255Cmathbf%257By%257D%255E%255Ctop%2520K%255E%257B-1%257D%255Cmathbf%257By%257D%2520-%2520%255Cfrac%257B1%257D%257B2%257D%255Clog%257CK%257C%2520-%2520%255Cfrac%257Bn%257D%257B2%257D%255Clog%25202%255Cpi" alt="equation" width="544" height="51"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This balances three terms:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Data fit&lt;/strong&gt; (&lt;code&gt;$-\frac{1}{2}\mathbf{y}^\top K^{-1}\mathbf{y}$&lt;/code&gt;) — how well the model explains the observations&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Complexity penalty&lt;/strong&gt; (&lt;code&gt;$-\frac{1}{2}\log|K|$&lt;/code&gt;) — penalises overly flexible models (Occam's razor)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Normalisation&lt;/strong&gt; (&lt;code&gt;$-\frac{n}{2}\log 2\pi$&lt;/code&gt;) — constant&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;This is &lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;maximum likelihood estimation&lt;/a&gt; applied to the marginal likelihood — "marginal" because we've integrated out the function values, leaving only the hyperparameters.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;scipy.optimize&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;minimize&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;neg_log_marginal_likelihood&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_n&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Negative log marginal likelihood (Ebden Eq. 10).&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;sigma_f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# ensure positive
&lt;/span&gt;    &lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;K&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;kernel_matrix&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma_f&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;sigma_n&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eye&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;log_lik&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="nf"&gt;inv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;
               &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;det&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;K&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
               &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;log_lik&lt;/span&gt;

&lt;span class="c1"&gt;# Optimise sigma_f and l, keeping sigma_n=0.3 fixed (matching original code)
&lt;/span&gt;&lt;span class="n"&gt;result&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;minimize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;neg_log_marginal_likelihood&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;x0&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;  &lt;span class="c1"&gt;# initial values from original code
&lt;/span&gt;    &lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;method&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Nelder-Mead&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;sigma_f_opt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;l_opt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Optimised: sigma_f=&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;sigma_f_opt&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;, l=&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;l_opt&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Optimised: sigma_f=1.34, l=1.04
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The log-likelihood landscape shows a clear optimum:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fezmtrno6ynpkumf51cye.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fezmtrno6ynpkumf51cye.webp" alt="Contour plot of log marginal likelihood over sigma_f and l, with the optimum marked at (1.34, 1.04)." width="800" height="636"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The optimiser starts at &lt;code&gt;$(\sigma_f, \ell) = (0.1, 0.5)$&lt;/code&gt; and converges to &lt;code&gt;$(1.34, 1.04)$&lt;/code&gt;, close to the tutorial's reported values of &lt;code&gt;$(1.27, 1.0)$&lt;/code&gt;. The small difference comes from Nelder-Mead finding a slightly different local optimum — the negative log-likelihoods differ by only 0.004. The complexity penalty prevents overfitting — without it, the model would shrink &lt;code&gt;$\ell$&lt;/code&gt; to interpolate every noisy observation exactly.&lt;/p&gt;

&lt;p&gt;A fully Bayesian approach would place priors on &lt;code&gt;$\sigma_f$&lt;/code&gt; and &lt;code&gt;$\ell$&lt;/code&gt; and integrate over them using &lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC&lt;/a&gt;. For most applications, the point estimate from marginal likelihood optimisation works well.&lt;/p&gt;

&lt;h3&gt;
  
  
  Validating with sklearn
&lt;/h3&gt;

&lt;p&gt;Let's confirm our from-scratch implementation matches scikit-learn:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.gaussian_process&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;GaussianProcessRegressor&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.gaussian_process.kernels&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;RBF&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;WhiteKernel&lt;/span&gt;

&lt;span class="n"&gt;kernel&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1.27&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="nc"&gt;RBF&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;length_scale&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="nc"&gt;WhiteKernel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;noise_level&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;gpr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;GaussianProcessRegressor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;kernel&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;kernel&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;gpr&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;mu_sk&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;std_sk&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gpr&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;return_std&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Max difference in mean: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mu_sk&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Max difference in mean: ~1e-14 (numerical precision)
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Our NumPy implementation matches scikit-learn to within floating-point precision.&lt;/p&gt;

&lt;h3&gt;
  
  
  When NOT to Use Gaussian Processes
&lt;/h3&gt;

&lt;p&gt;GPs have a fundamental limitation: &lt;strong&gt;cubic scaling&lt;/strong&gt;. Computing &lt;code&gt;$K^{-1}$&lt;/code&gt; is &lt;code&gt;$O(n^3)$&lt;/code&gt; in the number of training points. With hundreds of points it's effectively instant; with a few thousand it takes seconds; by tens of thousands it becomes impractical on a single machine.&lt;/p&gt;
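
&lt;p&gt;You can feel the growth by timing the inversion directly (absolute numbers depend on your hardware; only the trend matters):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import time
import numpy as np

for n in [200, 1000, 3000]:
    A = np.random.rand(n, n) + n * np.eye(n)  # well-conditioned test matrix
    t0 = time.perf_counter()
    np.linalg.inv(A)
    print(f'n={n}: {time.perf_counter() - t0:.3f}s')  # grows roughly as n**3
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;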

&lt;p&gt;For large datasets, consider:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Sparse GP approximations&lt;/strong&gt; — use a subset of inducing points to approximate the full GP&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Random Fourier features&lt;/strong&gt; — approximate the kernel with explicit feature maps&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;a href="https://sesen.ai/blog/backpropagation-neural-nets-from-first-principles" rel="noopener noreferrer"&gt;Neural networks&lt;/a&gt;&lt;/strong&gt; — scale linearly in data but lose calibrated uncertainty&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The RBF kernel also assumes &lt;strong&gt;stationarity&lt;/strong&gt; — the same smoothness everywhere. If your function is smooth in one region and jagged in another, you'd need a non-stationary kernel or a different model.&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Rasmussen &amp;amp; Williams (2006)
&lt;/h3&gt;

&lt;p&gt;The definitive reference is &lt;strong&gt;Rasmussen, C.E. &amp;amp; Williams, C.K.I. (2006)&lt;/strong&gt; &lt;a href="http://www.gaussianprocess.org/gpml/" rel="noopener noreferrer"&gt;&lt;em&gt;Gaussian Processes for Machine Learning&lt;/em&gt;&lt;/a&gt;, MIT Press. The full text is available free online.&lt;/p&gt;

&lt;p&gt;They define a Gaussian process formally:&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"A Gaussian process is a collection of random variables, any finite number of which have a joint Gaussian distribution."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;A GP is fully specified by its mean function &lt;code&gt;$m(x)$&lt;/code&gt; and covariance function &lt;code&gt;$k(x, x')$&lt;/code&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Df%28x%29%2520%255Csim%2520%255Cmathcal%257BGP%257D%255C%21%255Cleft%28m%28x%29%252C%255C%252C%2520k%28x%252C%2520x%27%29%255Cright%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Df%28x%29%2520%255Csim%2520%255Cmathcal%257BGP%257D%255C%21%255Cleft%28m%28x%29%252C%255C%252C%2520k%28x%252C%2520x%27%29%255Cright%29" alt="equation" width="278" height="26"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We set &lt;code&gt;$m(x) = 0$&lt;/code&gt; (standard practice — the kernel handles everything), so the covariance function alone defines the GP. This is the &lt;strong&gt;function-space view&lt;/strong&gt;: we're placing a prior directly on functions, not on parameters.&lt;/p&gt;

&lt;p&gt;Rasmussen &amp;amp; Williams also present the &lt;strong&gt;weight-space view&lt;/strong&gt; (Chapter 2.1). In Bayesian linear regression with basis functions &lt;code&gt;$\phi(x)$&lt;/code&gt;, the predictive distribution is:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Df_%2A%2520%255Cmid%2520%255Cmathbf%257Bx%257D_%2A%252C%2520X%252C%2520%255Cmathbf%257By%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%255C%21%255Cleft%28%255Cphi_%2A%255E%255Ctop%2520%255CSigma_p%2520%255CPhi%2520%28K%29%255E%257B-1%257D%2520%255Cmathbf%257By%257D%252C%255C%253B%2520%255Cphi_%2A%255E%255Ctop%2520%255CSigma_p%2520%255Cphi_%2A%255Cright%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Df_%2A%2520%255Cmid%2520%255Cmathbf%257Bx%257D_%2A%252C%2520X%252C%2520%255Cmathbf%257By%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%255C%21%255Cleft%28%255Cphi_%2A%255E%255Ctop%2520%255CSigma_p%2520%255CPhi%2520%28K%29%255E%257B-1%257D%2520%255Cmathbf%257By%257D%252C%255C%253B%2520%255Cphi_%2A%255E%255Ctop%2520%255CSigma_p%2520%255Cphi_%2A%255Cright%29" alt="equation" width="455" height="31"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;As the number of basis functions grows to infinity, this converges to the GP formulation — the function-space and weight-space views are equivalent. This connects GPs to &lt;a href="https://sesen.ai/blog/backpropagation-neural-nets-from-first-principles" rel="noopener noreferrer"&gt;neural networks&lt;/a&gt;: Neal (1996) showed that a single-hidden-layer neural network with infinitely many hidden units converges to a GP.&lt;/p&gt;
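
&lt;p&gt;The "infinitely many basis functions" claim can be checked numerically. The sketch below uses random Fourier features (the Rahimi &amp;amp; Recht construction, not code from the tutorial): a finite set of random cosine basis functions whose inner products approximate the RBF kernel as the number of features grows:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

rng = np.random.default_rng(0)
sigma_f, l, D = 1.27, 1.0, 20000

w = rng.normal(0.0, 1.0 / l, size=D)     # spectral frequencies of the RBF kernel
b = rng.uniform(0.0, 2 * np.pi, size=D)  # random phases

def z(x):
    """Random feature map: z(x) @ z(x') approximates k(x, x')."""
    return sigma_f * np.sqrt(2.0 / D) * np.cos(w * x + b)

x1, x2 = -0.3, 0.4
print(z(x1) @ z(x2))                                    # Monte Carlo estimate
print(sigma_f**2 * np.exp(-(x1 - x2)**2 / (2 * l**2)))  # exact RBF value
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;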

&lt;h3&gt;
  
  
  Ebden (2008) — Our Implementation Reference
&lt;/h3&gt;

&lt;p&gt;Our implementation follows &lt;strong&gt;Ebden, M. (2008)&lt;/strong&gt; &lt;a href="https://www.robots.ox.ac.uk/~mebden/reports/GPtutorial.pdf" rel="noopener noreferrer"&gt;"Gaussian Processes for Regression: A Quick Introduction"&lt;/a&gt;, which maps directly to our code:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Tutorial Equation&lt;/th&gt;
&lt;th&gt;Our Code&lt;/th&gt;
&lt;th&gt;What it computes&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Eq. (1): RBF kernel&lt;/td&gt;
&lt;td&gt;&lt;code&gt;rbf_kernel()&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;$k(x, x') = \sigma_f^2 \exp(-(x-x')^2 / 2\ell^2)$&lt;/code&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Eq. (4): &lt;code&gt;$K$&lt;/code&gt; matrix&lt;/td&gt;
&lt;td&gt;&lt;code&gt;kernel_matrix() + sigma_n^2 * I&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Training covariance&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Eq. (5): &lt;code&gt;$K_*$&lt;/code&gt; and &lt;code&gt;$K_{**}$&lt;/code&gt;
&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;kernel_matrix()&lt;/code&gt; calls&lt;/td&gt;
&lt;td&gt;Test-train and test-test covariance&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Eq. (8): Predictive mean&lt;/td&gt;
&lt;td&gt;&lt;code&gt;K_s @ K_inv @ y_train&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;$\bar{y}_* = K_* K^{-1}\mathbf{y}$&lt;/code&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Eq. (9): Predictive variance&lt;/td&gt;
&lt;td&gt;&lt;code&gt;K_ss - K_s @ K_inv @ K_s.T&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;$K_{**} - K_* K^{-1} K_*^\top$&lt;/code&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Eq. (10): Log-likelihood&lt;/td&gt;
&lt;td&gt;&lt;code&gt;neg_log_marginal_likelihood()&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Hyperparameter optimisation&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;h3&gt;
  
  
  Historical Context
&lt;/h3&gt;

&lt;p&gt;The idea of using stochastic processes for interpolation has deep roots:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Kolmogorov (1941)&lt;/strong&gt; and &lt;strong&gt;Wiener (1949)&lt;/strong&gt; — optimal linear prediction theory&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Matheron (1963)&lt;/strong&gt; — "kriging" in geostatistics, named after D.G. Krige's mining work in South Africa&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;O'Hagan (1978)&lt;/strong&gt; — formalised the Bayesian interpretation&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;MacKay (1998)&lt;/strong&gt; — introduced GPs to the machine learning community&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Neal (1996)&lt;/strong&gt; — proved the GP-neural network connection&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Rasmussen &amp;amp; Williams (2006)&lt;/strong&gt; — the modern comprehensive treatment&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Kriging and GP regression are mathematically identical — the geostatistics and machine learning communities developed the same ideas independently, using different vocabulary (variograms vs kernels, kriging variance vs posterior variance).&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="http://www.gaussianprocess.org/gpml/" rel="noopener noreferrer"&gt;Rasmussen &amp;amp; Williams (2006)&lt;/a&gt; — The full textbook, free online. Start with Chapter 2.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://www.robots.ox.ac.uk/~mebden/reports/GPtutorial.pdf" rel="noopener noreferrer"&gt;Ebden (2008)&lt;/a&gt; — The concise tutorial our code follows&lt;/li&gt;
&lt;li&gt;
&lt;a href="http://www.inference.org.uk/mackay/gpB.pdf" rel="noopener noreferrer"&gt;MacKay (1998)&lt;/a&gt; — "Introduction to Gaussian Processes" — a shorter, more accessible introduction&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Our &lt;a href="https://sesen.ai/blog/hyperparameter-optimization-grid-random-bayesian" rel="noopener noreferrer"&gt;Bayesian optimisation post&lt;/a&gt;&lt;/strong&gt; — see GPs in action as the surrogate model&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/bayesian/gaussian_process_regression.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Different kernels&lt;/strong&gt; — Implement the Matern 5/2 kernel and compare its predictions to the RBF. How do the confidence bands differ?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Multi-dimensional inputs&lt;/strong&gt; — Extend the GP to 2D inputs. The kernel becomes &lt;code&gt;$k(\mathbf{x}, \mathbf{x}') = \sigma_f^2 \exp(-\|\mathbf{x} - \mathbf{x}'\|^2 / 2\ell^2)$&lt;/code&gt;.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Noisy function&lt;/strong&gt; — Generate data from &lt;code&gt;$y = \sin(x) + \epsilon$&lt;/code&gt; with &lt;code&gt;$\epsilon \sim \mathcal{N}(0, 0.2^2)$&lt;/code&gt;. Fit a GP and observe how the confidence band captures the noise.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Bayesian optimisation&lt;/strong&gt; — Use your GP to implement a simple acquisition function (Expected Improvement) and optimise a 1D black-box function.&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/gaussian-process-playground" rel="noopener noreferrer"&gt;GP Regression Playground&lt;/a&gt; — Fit Gaussian processes to your own data and experiment with kernels in the browser&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/hyperparameter-optimization-grid-random-bayesian" rel="noopener noreferrer"&gt;Hyperparameter Optimization: Grid vs Random vs Bayesian&lt;/a&gt; — See GPs as the surrogate model in Bayesian optimisation&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt; — The Bayesian framework that underpins GP regression&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Island Hopping: Understanding Metropolis-Hastings&lt;/a&gt; — An alternative to point estimates for GP hyperparameters&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What does a Gaussian process actually model?
&lt;/h3&gt;

&lt;p&gt;A Gaussian process defines a probability distribution over functions, not just over individual predictions. Any finite collection of function values is modelled as a multivariate Gaussian distribution. The kernel function specifies how correlated any two function values are, which determines the smoothness and structure of the functions the GP considers plausible.&lt;/p&gt;

&lt;h3&gt;
  
  
  What does the length-scale hyperparameter control?
&lt;/h3&gt;

&lt;p&gt;The length-scale determines how quickly the function can vary. A short length-scale allows rapid, wiggly changes and can lead to overfitting, while a long length-scale enforces slow, smooth variation and can underfit. The optimal length-scale is typically found by maximising the marginal log-likelihood, which automatically balances data fit against model complexity.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why do Gaussian processes scale poorly to large datasets?
&lt;/h3&gt;

&lt;p&gt;GP prediction requires inverting the training covariance matrix, which is an O(n^3) operation. For 10,000 or more training points, this becomes computationally prohibitive. Sparse GP approximations, which use a smaller set of inducing points to approximate the full covariance, are the most common workaround.&lt;/p&gt;

&lt;h3&gt;
  
  
  How is GP regression related to Bayesian optimisation?
&lt;/h3&gt;

&lt;p&gt;In Bayesian optimisation, a GP serves as a surrogate model that approximates the expensive objective function. The GP's uncertainty estimates are critical: an acquisition function uses the predicted mean and variance to decide where to evaluate next, balancing exploitation (areas with good predicted values) and exploration (areas with high uncertainty).&lt;/p&gt;

&lt;h3&gt;
  
  
  Can I use kernels other than the RBF?
&lt;/h3&gt;

&lt;p&gt;Yes. The choice of kernel encodes your assumptions about the function. The Matern kernel allows you to control the differentiability of the function (the RBF assumes infinite differentiability). Periodic kernels capture repeating patterns. You can also combine kernels by adding or multiplying them to model functions with multiple structural components.&lt;/p&gt;
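
&lt;p&gt;As a head start on exercise 1 above, here is one standard form of the Matern 5/2 kernel, sketched to match the &lt;code&gt;rbf_kernel&lt;/code&gt; signature used earlier:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

def matern52_kernel(x1, x2, sigma_f, l):
    """Matern 5/2 kernel: samples are twice differentiable (rougher than RBF)."""
    r = np.abs(x1 - x2)
    s = np.sqrt(5.0) * r / l
    return sigma_f**2 * (1.0 + s + s**2 / 3.0) * np.exp(-s)

# Sums and products of valid kernels are themselves valid kernels
def combined_kernel(x1, x2):
    return matern52_kernel(x1, x2, 1.0, 1.0) + rbf_kernel(x1, x2, 0.5, 3.0)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;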

</description>
      <category>bayesian</category>
      <category>supervisedlearning</category>
      <category>probabilistic</category>
      <category>inference</category>
    </item>
    <item>
      <title>Hyperparameter Optimization: Grid vs Random vs Bayesian</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Fri, 10 Apr 2026 08:20:40 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/hyperparameter-optimization-grid-vs-random-vs-bayesian-gik</link>
      <guid>https://dev.to/berkan_sesen/hyperparameter-optimization-grid-vs-random-vs-bayesian-gik</guid>
      <description>&lt;p&gt;You've trained a Random Forest and it works — 85% accuracy out of the box. But you used the default hyperparameters. What if &lt;code&gt;n_estimators=500&lt;/code&gt; with &lt;code&gt;max_features=0.3&lt;/code&gt; and &lt;code&gt;min_samples_leaf=10&lt;/code&gt; pushes that to 91%? Only one way to find out: search.&lt;/p&gt;

&lt;p&gt;The problem is combinatorial. Our Random Forest has 4 hyperparameters. If you try 10 values for each in a grid, that's &lt;code&gt;$10^4 = 10{,}000$&lt;/code&gt; combinations. Each combination requires 5-fold cross-validation. That's 50,000 model fits — and we only have 4 dimensions. Neural networks routinely have 10–20 tunable hyperparameters, where exhaustive search is physically impossible.&lt;/p&gt;

&lt;p&gt;This post compares three strategies of increasing sophistication:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Grid Search&lt;/strong&gt; — try every combination on a predefined grid&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Random Search&lt;/strong&gt; — sample combinations at random (surprisingly effective)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Bayesian Optimization&lt;/strong&gt; — build a model of the objective and use it to choose the next point intelligently&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;We'll run all three on the same classification task, using the same Random Forest and the same hyperparameter ranges. You'll see that for an easy problem, all three reach ~90% accuracy — but the way they get there reveals fundamentally different philosophies about search. The real payoff of Bayesian optimization comes when evaluations are expensive: training a neural network for hours, running a simulation, or calling a paid API.&lt;/p&gt;

&lt;p&gt;If you've read the &lt;a href="https://sesen.ai/blog/genetic-algorithms-from-scratch" rel="noopener noreferrer"&gt;genetic algorithms post&lt;/a&gt;, you've already seen one approach to gradient-free optimization. Hyperparameter optimization is another — and Bayesian optimization is arguably the most elegant solution.&lt;/p&gt;

&lt;h2&gt;
  
  
  Quick Win: Run All Three Methods
&lt;/h2&gt;

&lt;p&gt;Click the badge to open the notebook and run everything yourself:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/optimisation/hyperparameter_optimization.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  Setup: The Problem
&lt;/h3&gt;

&lt;p&gt;We'll classify a synthetic dataset with 2,000 samples, 20 features, and 4 classes — a moderately challenging multi-class problem:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.datasets&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;make_classification&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.ensemble&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;RandomForestClassifier&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.model_selection&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;cross_val_score&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;GridSearchCV&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;RandomizedSearchCV&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;StratifiedKFold&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;metrics&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;skopt&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;gp_minimize&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;skopt.space&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Real&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Integer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Categorical&lt;/span&gt;

&lt;span class="c1"&gt;# Generate dataset (tuned to match the original Kaggle mobile price dataset's ~90% RF accuracy)
&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;make_classification&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;n_samples&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_informative&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;n_clusters_per_class&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;flip_y&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;n_classes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; features, &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;unique&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; classes, &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; samples&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# 20 features, 4 classes, 2000 samples
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;All three methods will tune the same 4 hyperparameters with 5-fold stratified cross-validation:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Hyperparameter&lt;/th&gt;
&lt;th&gt;Type&lt;/th&gt;
&lt;th&gt;Range&lt;/th&gt;
&lt;th&gt;What it controls&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;max_features&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Continuous&lt;/td&gt;
&lt;td&gt;[0.1, 1.0]&lt;/td&gt;
&lt;td&gt;Fraction of features per split&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;n_estimators&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Integer&lt;/td&gt;
&lt;td&gt;[100, 1000]&lt;/td&gt;
&lt;td&gt;Number of trees&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;min_samples_leaf&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Integer&lt;/td&gt;
&lt;td&gt;[5, 25]&lt;/td&gt;
&lt;td&gt;Minimum samples in a leaf&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;criterion&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Categorical&lt;/td&gt;
&lt;td&gt;{gini, entropy}&lt;/td&gt;
&lt;td&gt;Split quality measure&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;h3&gt;
  
  
  Method 1: Grid Search
&lt;/h3&gt;

&lt;p&gt;Grid search evaluates every combination on a predefined grid:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;param_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n_estimators&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;400&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;600&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;800&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;criterion&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gini&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;entropy&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;min_samples_leaf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;max_features&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="p"&gt;}&lt;/span&gt;

&lt;span class="n"&gt;grid_search&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;GridSearchCV&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;estimator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nc"&gt;RandomForestClassifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_jobs&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;param_grid&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;param_grid&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;scoring&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;accuracy&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;cv&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nc"&gt;StratifiedKFold&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_splits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;verbose&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;n_jobs&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;grid_search&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Grid Search — Best accuracy: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;grid_search&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_score_&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Combinations evaluated: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;grid_search&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cv_results_&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;mean_test_score&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best params: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;grid_search&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_params_&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;With &lt;code&gt;$4 \times 2 \times 3 \times 3 = 72$&lt;/code&gt; combinations and 5 folds each, that's 360 model fits.&lt;/p&gt;

&lt;h3&gt;
  
  
  Method 2: Random Search
&lt;/h3&gt;

&lt;p&gt;Random search samples 15 combinations uniformly at random from the ranges:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;param_distributions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n_estimators&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1001&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;criterion&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gini&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;entropy&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;min_samples_leaf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;26&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;max_features&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="p"&gt;}&lt;/span&gt;

&lt;span class="n"&gt;random_search&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;RandomizedSearchCV&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;estimator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nc"&gt;RandomForestClassifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_jobs&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;param_distributions&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;param_distributions&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;n_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;scoring&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;accuracy&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;cv&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nc"&gt;StratifiedKFold&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_splits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;verbose&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;n_jobs&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;random_search&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Random Search — Best accuracy: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;random_search&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_score_&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Combinations evaluated: 15&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best params: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;random_search&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_params_&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Only 15 combinations, 75 model fits — a fraction of grid search.&lt;/p&gt;

&lt;h3&gt;
  
  
  Method 3: Bayesian Optimization (Gaussian Process)
&lt;/h3&gt;

&lt;p&gt;Bayesian optimization builds a probabilistic model of the objective and uses it to decide where to evaluate next:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;evaluate_params&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;5-fold stratified CV accuracy for a RandomForest configuration.
    Returns negative accuracy (since gp_minimize minimises).&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;max_features&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_estimators&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;min_samples_leaf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;criterion&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;RandomForestClassifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;max_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;max_features&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;n_estimators&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_estimators&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;min_samples_leaf&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;min_samples_leaf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;criterion&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;criterion&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;n_jobs&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;kf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;StratifiedKFold&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_splits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;accuracies&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;train_idx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;val_idx&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;kf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;train_idx&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;train_idx&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="n"&gt;preds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;val_idx&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="n"&gt;accuracies&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;metrics&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;accuracy_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;val_idx&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;preds&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;accuracies&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;param_space&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
    &lt;span class="nc"&gt;Real&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;prior&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;uniform&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;max_features&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="nc"&gt;Integer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n_estimators&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="nc"&gt;Integer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;25&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;min_samples_leaf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="nc"&gt;Categorical&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gini&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;entropy&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;criterion&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="n"&gt;result&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;gp_minimize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;evaluate_params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;dimensions&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;param_space&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;n_calls&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;n_random_starts&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;verbose&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;best_params&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;max_features&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n_estimators&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;min_samples_leaf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;criterion&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
    &lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;
&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Bayesian (GP) — Best accuracy: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fun&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Evaluations: 15 (10 random + 5 guided)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best params: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_params&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Same 15 evaluations, but the last 5 are &lt;em&gt;informed&lt;/em&gt; by the GP model built from the first 10.&lt;/p&gt;
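
&lt;p&gt;If you want to watch the handover from random to guided sampling, the result object returned by &lt;code&gt;gp_minimize&lt;/code&gt; records the full trajectory in its &lt;code&gt;x_iters&lt;/code&gt; and &lt;code&gt;func_vals&lt;/code&gt; attributes. A quick sketch (the formatting is illustrative):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Each entry in result.x_iters is one evaluated configuration;
# result.func_vals holds the corresponding negated accuracies.
for i, (params, score) in enumerate(zip(result.x_iters, result.func_vals), 1):
    phase = 'random' if i &amp;lt;= 10 else 'guided'
    print(f'{i:2d} [{phase}] acc={-score:.4f} {params}')
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;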

&lt;h3&gt;
  
  
  Results Comparison
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Method&lt;/th&gt;
&lt;th&gt;Best CV Accuracy&lt;/th&gt;
&lt;th&gt;Evaluations&lt;/th&gt;
&lt;th&gt;Strategy&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Grid Search&lt;/td&gt;
&lt;td&gt;~90%&lt;/td&gt;
&lt;td&gt;72&lt;/td&gt;
&lt;td&gt;Exhaustive (predefined grid)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Random Search&lt;/td&gt;
&lt;td&gt;~90%&lt;/td&gt;
&lt;td&gt;15&lt;/td&gt;
&lt;td&gt;Uniform random sampling&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Bayesian (GP)&lt;/td&gt;
&lt;td&gt;~90%&lt;/td&gt;
&lt;td&gt;15 (10 random + 5 guided)&lt;/td&gt;
&lt;td&gt;Model-based sequential&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;All three methods converge to approximately the same accuracy — around 90%. For this well-behaved 4-dimensional problem, the objective landscape is smooth enough that random search finds good regions quickly.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fvi15fh5j4a7pq47oqqa5.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fvi15fh5j4a7pq47oqqa5.webp" alt="Best cross-validation accuracy over successive evaluations for Grid Search, Random Search, and Bayesian Optimization" width="800" height="395"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The convergence chart tells the real story. Grid search plods through its predefined grid systematically. Random search jumps around but finds a good region early. Bayesian optimization starts random (first 10 points) then accelerates — the GP model narrows in on promising regions.&lt;/p&gt;

&lt;p&gt;The punchline: &lt;strong&gt;the methods differ most when evaluations are expensive.&lt;/strong&gt; For a Random Forest on 2,000 samples, each evaluation takes under a second and you can afford to be wasteful. Train a Transformer for 8 hours per evaluation, and the difference between 15 smart evaluations and 72 brute-force ones is the difference between five days and more than three weeks of compute.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;All three methods solve the same problem — find the hyperparameter combination that maximises cross-validated accuracy — but they navigate the search space in fundamentally different ways.&lt;/p&gt;

&lt;h3&gt;
  
  
  Grid Search: The Cartesian Product
&lt;/h3&gt;

&lt;p&gt;Grid search evaluates every point on a regular lattice. If you specify 4 values for &lt;code&gt;n_estimators&lt;/code&gt;, 2 for &lt;code&gt;criterion&lt;/code&gt;, 3 for &lt;code&gt;min_samples_leaf&lt;/code&gt;, and 3 for &lt;code&gt;max_features&lt;/code&gt;, you get &lt;code&gt;$4 \times 2 \times 3 \times 3 = 72$&lt;/code&gt; combinations. It's the Cartesian product of your parameter lists.&lt;/p&gt;
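
&lt;p&gt;A quick sanity check of that count, using the value lists from &lt;code&gt;param_grid&lt;/code&gt; above (this snippet is illustrative, not part of the original notebook):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;from itertools import product

# Cartesian product of the four value lists defined in param_grid.
combos = list(product([200, 400, 600, 800], ['gini', 'entropy'],
                      [5, 10, 20], [0.3, 0.5, 0.8]))
print(len(combos))  # 72, i.e. 4 * 2 * 3 * 3
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;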

&lt;p&gt;Think of a tourist exploring a city by visiting every intersection on the street grid. Thorough? Yes. Efficient? Not even close. The problem is that most intersections are uninteresting — the accuracy surface for a Random Forest is usually smooth, so neighbouring grid points give nearly identical results.&lt;/p&gt;

&lt;p&gt;The deeper issue is the &lt;strong&gt;curse of dimensionality&lt;/strong&gt;. In &lt;code&gt;$d$&lt;/code&gt; dimensions with &lt;code&gt;$k$&lt;/code&gt; values per dimension, the grid has &lt;code&gt;$k^d$&lt;/code&gt; points. With 10 values per dimension: 2D gives 100 points (fine), 4D gives 10,000 (slow), 10D gives 10 billion (impossible). Grid search is fundamentally limited to low-dimensional problems with coarse grids.&lt;/p&gt;

&lt;h3&gt;
  
  
  Random Search: The Bergstra-Bengio Insight
&lt;/h3&gt;

&lt;p&gt;Random search replaces the grid with uniform random samples. This seems like it should be worse — throwing darts blindfolded. But Bergstra and Bengio (2012) showed it's almost always &lt;em&gt;better&lt;/em&gt; than grid search for the same computational budget.&lt;/p&gt;

&lt;p&gt;The insight is elegant. Suppose only 2 of your 4 hyperparameters actually matter (a common situation). On a &lt;code&gt;$4 \times 4$&lt;/code&gt; grid in 2D, you get 16 unique points — but only 4 unique values per dimension. Random search with 16 points gives you 16 unique values per dimension. You're covering the important dimensions far more densely.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9mbi2tmelamw338o9ucw.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9mbi2tmelamw338o9ucw.webp" alt="Three-panel comparison showing how Grid Search, Random Search, and Bayesian Optimization sample a 2D hyperparameter space" width="800" height="281"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The figure makes this vivid. Grid search evaluates 9 points, but only 3 unique values along each axis. Random search with 9 points explores 9 unique values per axis. Bayesian optimization clusters its samples in the high-accuracy region (top-right), exploring broadly first and then exploiting the best area.&lt;/p&gt;
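
&lt;p&gt;You can reproduce the coverage argument numerically. A minimal sketch (array names are illustrative):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

rng = np.random.default_rng(0)

# 16 points on a 4x4 lattice: only 4 distinct values per axis.
axis = np.linspace(0.0, 1.0, 4)
grid = np.array([(a, b) for a in axis for b in axis])

# 16 uniform random points: 16 distinct values per axis (almost surely).
rand = rng.uniform(size=(16, 2))

print(len(np.unique(grid[:, 0])), len(np.unique(rand[:, 0])))  # 4 vs 16
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;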

&lt;h3&gt;
  
  
  Bayesian Optimization: Learning from History
&lt;/h3&gt;

&lt;p&gt;Grid and random search are &lt;strong&gt;memoryless&lt;/strong&gt; — each evaluation is independent of the others. The 15th evaluation in random search knows nothing about the first 14.&lt;/p&gt;

&lt;p&gt;Bayesian optimization is fundamentally different. It maintains a &lt;strong&gt;model&lt;/strong&gt; of the objective function — a Gaussian process (GP) that predicts what the accuracy will be at any point in the hyperparameter space, along with an uncertainty estimate. After each evaluation, the model updates, and an &lt;strong&gt;acquisition function&lt;/strong&gt; decides where to evaluate next.&lt;/p&gt;

&lt;p&gt;The three components:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Surrogate model (Gaussian Process)&lt;/strong&gt; — A probabilistic model that interpolates between observed points. At observed points, uncertainty is zero. Far from observations, uncertainty is high. Think of it as drawing a smooth surface through scattered data points, with error bars that widen between points.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Acquisition function&lt;/strong&gt; — A cheap-to-evaluate function that balances &lt;strong&gt;exploration&lt;/strong&gt; (high uncertainty) and &lt;strong&gt;exploitation&lt;/strong&gt; (high predicted value). The most common is &lt;strong&gt;Expected Improvement (EI)&lt;/strong&gt;: "how much better than the current best do I expect this point to be?"&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Optimize the acquisition function&lt;/strong&gt; — Find the point that maximises EI (a standard optimization problem, but cheap because the GP is fast to evaluate), then run the expensive evaluation there.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;This loop — fit GP, maximize acquisition, evaluate, repeat — is what makes Bayesian optimization &lt;em&gt;sequential&lt;/em&gt; and &lt;em&gt;adaptive&lt;/em&gt;. If you've read the &lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;Bayesian inference post&lt;/a&gt;, you'll recognise the pattern: start with a prior (the GP), observe data, update to a posterior, and use the posterior to make decisions. The GP generalises this from updating parameter estimates to updating an entire &lt;em&gt;function&lt;/em&gt; estimate.&lt;/p&gt;
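
&lt;p&gt;To make the loop concrete, here is a minimal sketch of it on a 1D toy problem, using scikit-learn's GP and the closed-form Expected Improvement. Everything here (the toy objective &lt;code&gt;f&lt;/code&gt;, the candidate grid, the loop itself) is an illustration of the idea, not skopt's internals:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from scipy.stats import norm
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern

def f(x):
    # Toy objective to maximise (unknown to the optimiser in practice).
    return -(x - 0.6) ** 2 + 0.1 * np.sin(20 * x)

rng = np.random.default_rng(42)
X_obs = rng.uniform(size=(3, 1))   # a few random initial points
y_obs = f(X_obs).ravel()

candidates = np.linspace(0.0, 1.0, 500).reshape(-1, 1)
gp = GaussianProcessRegressor(kernel=Matern(nu=2.5), normalize_y=True)

for _ in range(10):
    gp.fit(X_obs, y_obs)                                  # 1. fit the surrogate
    mu, sigma = gp.predict(candidates, return_std=True)
    best = y_obs.max()
    z = (mu - best) / np.maximum(sigma, 1e-9)
    ei = (mu - best) * norm.cdf(z) + sigma * norm.pdf(z)  # 2. Expected Improvement
    x_next = candidates[np.argmax(ei)]                    # 3. maximise acquisition
    X_obs = np.vstack([X_obs, x_next])                    # 4. evaluate and repeat
    y_obs = np.append(y_obs, f(x_next))

print(X_obs[np.argmax(y_obs)].item(), y_obs.max())
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;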

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The GP Surrogate in Action
&lt;/h3&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsw7mvofuvu4lgpbsrevl.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsw7mvofuvu4lgpbsrevl.gif" alt="A Gaussian Process surrogate model fitting observations one-by-one, with the acquisition function (Expected Improvement) highlighted below — watch how uncertainty shrinks near observations and the acquisition peak shifts to promising regions" width="1043" height="657"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The animation shows Bayesian optimization on a 1D function. The blue line is the true (unknown) objective. The black line is the GP's mean prediction, with the shaded region showing the 95% confidence interval. The green area below is the Expected Improvement — it peaks where the model expects to find values better than the current best.&lt;/p&gt;

&lt;p&gt;Watch how the GP evolves:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Early frames:&lt;/strong&gt; Few observations, wide uncertainty, EI spreads across unexplored regions (exploration)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Middle frames:&lt;/strong&gt; The GP starts to capture the function's shape, uncertainty shrinks near data points&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Late frames:&lt;/strong&gt; EI concentrates around the global optimum as the model becomes confident elsewhere (exploitation)&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  GP Predictive Distribution
&lt;/h3&gt;

&lt;p&gt;At any unobserved point &lt;code&gt;$\mathbf{x}_*$&lt;/code&gt;, the GP predicts a Gaussian distribution:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Df%28%255Cmathbf%257Bx%257D_%2A%29%2520%255Csim%2520%255Cmathcal%257BN%257D%28%255Cmu%28%255Cmathbf%257Bx%257D_%2A%29%252C%2520%255Csigma%255E2%28%255Cmathbf%257Bx%257D_%2A%29%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Df%28%255Cmathbf%257Bx%257D_%2A%29%2520%255Csim%2520%255Cmathcal%257BN%257D%28%255Cmu%28%255Cmathbf%257Bx%257D_%2A%29%252C%2520%255Csigma%255E2%28%255Cmathbf%257Bx%257D_%2A%29%29" alt="equation" width="269" height="27"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Where:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;$\mu(\mathbf{x}_*)$&lt;/code&gt; — the mean prediction, interpolating nearby observations&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$\sigma^2(\mathbf{x}_*)$&lt;/code&gt; — the predictive variance: high far from data, zero at observed points&lt;/li&gt;
&lt;li&gt;The predictions are conditioned on all previous observations &lt;code&gt;$\{(\mathbf{x}_i, y_i)\}_{i=1}^n$&lt;/code&gt;
&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The mean and variance are computed in closed form from the kernel function (typically Matérn 5/2 in skopt):&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu%28%255Cmathbf%257Bx%257D_%2A%29%2520%253D%2520%255Cmathbf%257Bk%257D_%2A%255ET%2520%28%255Cmathbf%257BK%257D%2520%252B%2520%255Csigma_n%255E2%2520%255Cmathbf%257BI%257D%29%255E%257B-1%257D%2520%255Cmathbf%257By%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu%28%255Cmathbf%257Bx%257D_%2A%29%2520%253D%2520%255Cmathbf%257Bk%257D_%2A%255ET%2520%28%255Cmathbf%257BK%257D%2520%252B%2520%255Csigma_n%255E2%2520%255Cmathbf%257BI%257D%29%255E%257B-1%257D%2520%255Cmathbf%257By%257D" alt="equation" width="265" height="28"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Csigma%255E2%28%255Cmathbf%257Bx%257D_%2A%29%2520%253D%2520k%28%255Cmathbf%257Bx%257D_%2A%252C%2520%255Cmathbf%257Bx%257D_%2A%29%2520-%2520%255Cmathbf%257Bk%257D_%2A%255ET%2520%28%255Cmathbf%257BK%257D%2520%252B%2520%255Csigma_n%255E2%2520%255Cmathbf%257BI%257D%29%255E%257B-1%257D%2520%255Cmathbf%257Bk%257D_%2A" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Csigma%255E2%28%255Cmathbf%257Bx%257D_%2A%29%2520%253D%2520k%28%255Cmathbf%257Bx%257D_%2A%252C%2520%255Cmathbf%257Bx%257D_%2A%29%2520-%2520%255Cmathbf%257Bk%257D_%2A%255ET%2520%28%255Cmathbf%257BK%257D%2520%252B%2520%255Csigma_n%255E2%2520%255Cmathbf%257BI%257D%29%255E%257B-1%257D%2520%255Cmathbf%257Bk%257D_%2A" alt="equation" width="406" height="28"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Where &lt;code&gt;$\mathbf{K}$&lt;/code&gt; is the kernel matrix between all observed points, &lt;code&gt;$\mathbf{k}_*$&lt;/code&gt; is the kernel vector between &lt;code&gt;$\mathbf{x}_*$&lt;/code&gt; and observed points, and &lt;code&gt;$\sigma_n^2$&lt;/code&gt; is the observation noise.&lt;/p&gt;

&lt;p&gt;In plain English: the GP prediction at a new point is a &lt;strong&gt;weighted average&lt;/strong&gt; of observed values, where the weights come from how "similar" (in kernel space) the new point is to each observation.&lt;/p&gt;
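
&lt;p&gt;As a sketch, the two formulas translate almost line-for-line into numpy. For brevity this uses an RBF kernel rather than skopt's Matérn 5/2, and the function names and shapes are illustrative:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

def rbf(A, B, length_scale=1.0):
    # Pairwise RBF kernel between rows of A (n, d) and B (m, d).
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq / length_scale ** 2)

def gp_posterior(X, y, X_star, noise=1e-6):
    K = rbf(X, X) + noise * np.eye(len(X))      # (K + sigma_n^2 I)
    k_star = rbf(X, X_star)                     # k_* for each query point
    alpha = np.linalg.solve(K, y)
    mu = k_star.T @ alpha                       # k_*^T (K + sigma_n^2 I)^-1 y
    v = np.linalg.solve(K, k_star)
    var = rbf(X_star, X_star).diagonal() - np.einsum('ij,ij-&amp;gt;j', k_star, v)
    return mu, var

# At an observed point the predictive variance collapses to (near) zero:
X = np.array([[0.1], [0.5], [0.9]])
y = np.array([0.2, 0.8, 0.3])
mu, var = gp_posterior(X, y, np.array([[0.5], [0.7]]))
print(mu.round(3), var.round(6))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;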

&lt;h3&gt;
  
  
  Acquisition Functions: Deciding Where to Look Next
&lt;/h3&gt;

&lt;p&gt;The acquisition function turns the GP's prediction into a decision: where should we evaluate next? Three common choices:&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Expected Improvement (EI)&lt;/strong&gt; — the default in &lt;code&gt;skopt&lt;/code&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BEI%257D%28%255Cmathbf%257Bx%257D%29%2520%253D%2520%255Cmathbb%257BE%257D%255Cleft%255B%255Cmax%28f%28%255Cmathbf%257Bx%257D%29%2520-%2520f%28%255Cmathbf%257Bx%257D%255E%252B%29%252C%25200%29%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BEI%257D%28%255Cmathbf%257Bx%257D%29%2520%253D%2520%255Cmathbb%257BE%257D%255Cleft%255B%255Cmax%28f%28%255Cmathbf%257Bx%257D%29%2520-%2520f%28%255Cmathbf%257Bx%257D%255E%252B%29%252C%25200%29%255Cright%255D" alt="equation" width="357" height="30"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Where &lt;code&gt;$f(\mathbf{x}^+)$&lt;/code&gt; is the best value observed so far. EI asks: "in expectation, how much will this point improve on the current best?" Points with high predicted mean &lt;em&gt;and&lt;/em&gt; high uncertainty score well — this naturally balances exploitation and exploration.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Lower Confidence Bound (LCB)&lt;/strong&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BLCB%257D%28%255Cmathbf%257Bx%257D%29%2520%253D%2520%255Cmu%28%255Cmathbf%257Bx%257D%29%2520-%2520%255Ckappa%2520%255Ccdot%2520%255Csigma%28%255Cmathbf%257Bx%257D%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BLCB%257D%28%255Cmathbf%257Bx%257D%29%2520%253D%2520%255Cmu%28%255Cmathbf%257Bx%257D%29%2520-%2520%255Ckappa%2520%255Ccdot%2520%255Csigma%28%255Cmathbf%257Bx%257D%29" alt="equation" width="273" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;In skopt's minimisation framing, you pick the point that minimises the mean minus a multiple of the standard deviation. The parameter &lt;code&gt;$\kappa$&lt;/code&gt; controls the exploration–exploitation trade-off directly: higher &lt;code&gt;$\kappa$&lt;/code&gt; means more exploration.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Probability of Improvement (PI)&lt;/strong&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BPI%257D%28%255Cmathbf%257Bx%257D%29%2520%253D%2520P%28f%28%255Cmathbf%257Bx%257D%29%2520%253E%2520f%28%255Cmathbf%257Bx%257D%255E%252B%29%29%2520%253D%2520%255CPhi%255Cleft%28%255Cfrac%257B%255Cmu%28%255Cmathbf%257Bx%257D%29%2520-%2520f%28%255Cmathbf%257Bx%257D%255E%252B%29%257D%257B%255Csigma%28%255Cmathbf%257Bx%257D%29%257D%255Cright%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BPI%257D%28%255Cmathbf%257Bx%257D%29%2520%253D%2520P%28f%28%255Cmathbf%257Bx%257D%29%2520%253E%2520f%28%255Cmathbf%257Bx%257D%255E%252B%29%29%2520%253D%2520%255CPhi%255Cleft%28%255Cfrac%257B%255Cmu%28%255Cmathbf%257Bx%257D%29%2520-%2520f%28%255Cmathbf%257Bx%257D%255E%252B%29%257D%257B%255Csigma%28%255Cmathbf%257Bx%257D%29%257D%255Cright%29" alt="equation" width="508" height="60"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The probability that a point improves on the current best. Simpler than EI, but tends to exploit too aggressively — it doesn't care &lt;em&gt;how much&lt;/em&gt; better a point might be, only whether it's better at all.&lt;/p&gt;
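
&lt;p&gt;All three have closed forms under the GP's Gaussian predictive distribution. A minimal numpy sketch, written in the maximisation framing used for EI and PI above (function names are illustrative):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from scipy.stats import norm

def expected_improvement(mu, sigma, best):
    # mu, sigma: GP mean and std at candidate points; best: f(x+) so far.
    z = (mu - best) / np.maximum(sigma, 1e-12)
    return (mu - best) * norm.cdf(z) + sigma * norm.pdf(z)

def lower_confidence_bound(mu, sigma, kappa=1.96):
    # In skopt's minimisation framing you would pick the argmin of this.
    return mu - kappa * sigma

def probability_of_improvement(mu, sigma, best):
    return norm.cdf((mu - best) / np.maximum(sigma, 1e-12))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;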

&lt;h3&gt;
  
  
  Exploration vs Exploitation
&lt;/h3&gt;

&lt;p&gt;This tension appears everywhere in machine learning. In &lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;Q-learning&lt;/a&gt;, epsilon-greedy balances trying new actions (exploration) with choosing the best-known action (exploitation). In &lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC&lt;/a&gt;, the proposal distribution must explore the parameter space while spending time in high-probability regions.&lt;/p&gt;

&lt;p&gt;Bayesian optimization handles this &lt;em&gt;automatically&lt;/em&gt; through the acquisition function. EI naturally favours:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Points with high predicted accuracy (exploitation)&lt;/li&gt;
&lt;li&gt;Points with high uncertainty (exploration)&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;No manual schedule needed — the GP's uncertainty estimates do the work.&lt;/p&gt;

&lt;h3&gt;
  
  
  When NOT to Use Bayesian Optimization
&lt;/h3&gt;

&lt;p&gt;Bayesian optimization isn't always the right tool:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Cheap evaluations&lt;/strong&gt; — If each evaluation takes seconds (like our Random Forest), random search with 100 iterations is simpler and nearly as effective&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;High dimensions&lt;/strong&gt; (&lt;code&gt;$d &amp;gt; 20$&lt;/code&gt;) — GPs scale poorly with dimensionality. The kernel becomes uninformative and the acquisition function has too many local optima&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Massive parallelism&lt;/strong&gt; — If you have 1,000 GPUs, you can evaluate 1,000 random configurations simultaneously. Bayesian optimization is inherently sequential (each evaluation depends on all previous ones)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Discrete/conditional spaces&lt;/strong&gt; — GPs assume smooth, continuous objectives. Deeply nested conditional hyperparameters (e.g., "layer type" → "layer-specific params") are better handled by tree-based methods like Optuna's TPE&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;&lt;strong&gt;Rule of thumb:&lt;/strong&gt; Use Bayesian optimization when each evaluation costs minutes to hours (neural network training, simulation runs, expensive API calls) and you're in 5–15 dimensions.&lt;/p&gt;

&lt;h3&gt;
  
  
  Hyperparameters
&lt;/h3&gt;

&lt;p&gt;All values come directly from the original code in &lt;code&gt;quant_code/python/HyperParameter_Optimization/src/&lt;/code&gt;:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Parameter&lt;/th&gt;
&lt;th&gt;Source file&lt;/th&gt;
&lt;th&gt;Value&lt;/th&gt;
&lt;th&gt;Why&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;max_features&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;gp_min_optim.py&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Real(0.1, 1.0)&lt;/td&gt;
&lt;td&gt;Fraction of features per split&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;n_estimators&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;gp_min_optim.py&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Integer(100, 1000)&lt;/td&gt;
&lt;td&gt;Number of trees in the forest&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;min_samples_leaf&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;gp_min_optim.py&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Integer(5, 25)&lt;/td&gt;
&lt;td&gt;Minimum leaf samples (regularisation)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;criterion&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;grid_n_random_search.py&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;{gini, entropy}&lt;/td&gt;
&lt;td&gt;Split quality metric&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;n_calls&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;gp_min_optim.py&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;15&lt;/td&gt;
&lt;td&gt;Total evaluations for GP&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;n_random_starts&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;gp_min_optim.py&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;10&lt;/td&gt;
&lt;td&gt;Initial random exploration&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;cv&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;all files&lt;/td&gt;
&lt;td&gt;5-fold stratified&lt;/td&gt;
&lt;td&gt;Evaluation protocol&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;h2&gt;
  
  
  Deep Dive: The Papers
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Bergstra &amp;amp; Bengio (2012) — Random Search for Hyper-Parameter Optimization
&lt;/h3&gt;

&lt;p&gt;The key paper that changed how practitioners think about hyperparameter search. &lt;a href="https://jmlr.org/papers/v13/bergstra12a.html" rel="noopener noreferrer"&gt;&lt;em&gt;Random Search for Hyper-Parameter Optimization&lt;/em&gt;&lt;/a&gt; (JMLR) demonstrated that random search is more efficient than grid search for most practical problems.&lt;/p&gt;

&lt;p&gt;The core theorem is deceptively simple. Define the &lt;strong&gt;effective dimensionality&lt;/strong&gt; of a search problem as the number of hyperparameters that significantly affect performance. Bergstra and Bengio showed:&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"Random search is more efficient than grid search because it allows each trial to explore a different value of every hyperparameter. For problems with low effective dimensionality, this results in dramatically better coverage of the important dimensions."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Their famous figure (reproduced in our sampling comparison above) shows a 2D search where only one dimension matters. Grid search with 9 points wastes 6 of them evaluating the same 3 values of the important dimension. Random search with 9 points gives 9 unique values along the important dimension.&lt;/p&gt;

&lt;p&gt;The practical implication: for the same computational budget, random search achieves equal or better results than grid search in virtually all cases. Short of needing a reproducible, exhaustive sweep over a handful of discrete values, there is essentially no scenario where grid search is preferable.&lt;/p&gt;

&lt;h3&gt;
  
  
  Snoek, Larochelle &amp;amp; Adams (2012) — Practical Bayesian Optimization
&lt;/h3&gt;

&lt;p&gt;&lt;a href="https://papers.nips.cc/paper/2012/hash/05311655a15b75fab86956663e1819cd-Abstract.html" rel="noopener noreferrer"&gt;&lt;em&gt;Practical Bayesian Optimization of Machine Learning Algorithms&lt;/em&gt;&lt;/a&gt; (NeurIPS) introduced the Spearmint system and demonstrated that Bayesian optimization could match or beat human experts at tuning neural networks.&lt;/p&gt;

&lt;p&gt;Their algorithm:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Input: Search space S, objective function f, budget N
1. Initialize with n₀ random evaluations
2. For i = n₀+1 to N:
   a. Fit GP to all observations {(x₁,y₁), ..., (xᵢ₋₁,yᵢ₋₁)}
   b. Find xᵢ = argmax_x EI(x) using the GP
   c. Evaluate yᵢ = f(xᵢ)
3. Return x with best observed y
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is exactly what &lt;code&gt;skopt.gp_minimize&lt;/code&gt; implements. The key contributions beyond earlier work:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Automatic Relevance Determination (ARD) kernels&lt;/strong&gt; — learns which hyperparameters matter most&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Integrated acquisition function&lt;/strong&gt; — marginalises over GP hyperparameters rather than fixing them&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Pending evaluations&lt;/strong&gt; — supports parallel evaluations through "fantasised" outcomes&lt;/li&gt;
&lt;/ul&gt;
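
&lt;p&gt;As a minimal sketch of that loop via &lt;code&gt;skopt&lt;/code&gt; (a toy one-dimensional objective stands in for the real cross-validation loss; the parameter names match the pseudocode above):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;from skopt import gp_minimize
from skopt.space import Real

def f(params):
    x, = params
    return (x - 0.3) ** 2              # toy objective; swap in a CV loss

res = gp_minimize(
    f,
    [Real(0.0, 1.0, name='x')],        # search space S
    n_calls=15,                        # total budget N
    n_random_starts=10,                # n0 initial random evaluations
    acq_func='EI',                     # Expected Improvement
    random_state=42,
)
print(res.x, res.fun)                  # best observed x and its y
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;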

&lt;p&gt;Snoek et al. demonstrated that Bayesian optimization with 30 evaluations outperformed a human expert who had months to tune the same neural networks. The quote that launched a thousand AutoML papers:&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"Bayesian optimization spent 2 minutes of overhead in exchange for saving the researchers days of manual tuning."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;h3&gt;
  
  
  Mockus et al. (1978) &amp;amp; Jones et al. (1998) — The Origins
&lt;/h3&gt;

&lt;p&gt;Bayesian optimization long predates its use in machine learning. &lt;a href="https://link.springer.com/chapter/10.1007/3-540-07165-2_55" rel="noopener noreferrer"&gt;Mockus, Tiesis, and Žilinskas (1978)&lt;/a&gt; formalised the idea of using a Bayesian model to guide sequential optimization in their work on &lt;em&gt;Bayesian Methods for Seeking the Extremum&lt;/em&gt;. They introduced the Expected Improvement criterion, proving it is the optimal policy for a one-step lookahead under certain assumptions.&lt;/p&gt;

&lt;p&gt;Two decades later, &lt;a href="https://link.springer.com/article/10.1023/A:1008306431147" rel="noopener noreferrer"&gt;Jones, Schonlau, and Welch (1998)&lt;/a&gt; brought these ideas to engineering design optimization with the &lt;strong&gt;EGO&lt;/strong&gt; (Efficient Global Optimization) algorithm. Their paper &lt;em&gt;Efficient Global Optimization of Expensive Black-Box Functions&lt;/em&gt; established the GP + EI framework that Snoek et al. later applied to ML hyperparameters.&lt;/p&gt;

&lt;h3&gt;
  
  
  Historical Timeline
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Mockus et al. (1978)&lt;/strong&gt; — Bayesian optimization and Expected Improvement&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Jones et al. (1998)&lt;/strong&gt; — EGO: GP + EI for engineering design&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Hutter et al. (2011)&lt;/strong&gt; — &lt;a href="https://ml.informatik.uni-freiburg.de/papers/11-LION5-SMAC.pdf" rel="noopener noreferrer"&gt;SMAC&lt;/a&gt;: random forests as surrogate instead of GPs (scales to high dimensions)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Bergstra et al. (2011)&lt;/strong&gt; — &lt;a href="https://papers.nips.cc/paper/2011/hash/86e8f7ab32cfd12577bc2619bc635690-Abstract.html" rel="noopener noreferrer"&gt;Hyperopt&lt;/a&gt; and TPE: tree-structured Parzen estimators (handles conditional spaces)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Bergstra &amp;amp; Bengio (2012)&lt;/strong&gt; — The random search paper&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Snoek et al. (2012)&lt;/strong&gt; — Practical Bayesian optimization (Spearmint)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Akiba et al. (2019)&lt;/strong&gt; — &lt;a href="https://arxiv.org/abs/1907.10902" rel="noopener noreferrer"&gt;Optuna&lt;/a&gt;: define-by-run API, pruning, modern TPE (now the most popular framework)&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Modern Alternatives: Optuna and Hyperopt
&lt;/h3&gt;

&lt;p&gt;The original source code also includes implementations in Optuna and Hyperopt. Both use &lt;strong&gt;Tree-structured Parzen Estimators (TPE)&lt;/strong&gt; instead of Gaussian Processes:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Optuna (from optuna_optim.py) — define-by-run API
&lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;optuna&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;objective&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;criterion&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;suggest_categorical&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;criterion&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gini&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;entropy&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="n"&gt;n_estimators&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;suggest_int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n_estimators&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;min_samples_leaf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;suggest_int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;min_samples_leaf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;25&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;max_features&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trial&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;suggest_float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;max_features&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;RandomForestClassifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;n_estimators&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_estimators&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;criterion&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;criterion&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;max_features&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;max_features&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;min_samples_leaf&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;min_samples_leaf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;n_jobs&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;kf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;StratifiedKFold&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_splits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;train_idx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;val_idx&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;kf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;train_idx&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;train_idx&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;metrics&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;accuracy_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;val_idx&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;val_idx&lt;/span&gt;&lt;span class="p"&gt;])))&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;study&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optuna&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;create_study&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;direction&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;minimize&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;study&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;optimize&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;objective&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_trials&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;TPE models &lt;code&gt;$P(\mathbf{x} | y &amp;lt; y^*)$&lt;/code&gt; and &lt;code&gt;$P(\mathbf{x} | y \geq y^*)$&lt;/code&gt; separately (two Parzen estimators), then selects points that maximise the ratio of the "good" density to the "bad" one. This handles conditional and discrete parameters more naturally than GPs.&lt;/p&gt;
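
&lt;p&gt;To make that ratio concrete, here is a small NumPy/SciPy sketch of the selection step (our own illustration of the idea, not Optuna's internals):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(0)
x_obs = rng.uniform(0, 1, 50)
y_obs = (x_obs - 0.3) ** 2 + rng.normal(0, 0.01, 50)  # toy objective

y_star = np.quantile(y_obs, 0.25)            # good/bad threshold
l = gaussian_kde(x_obs[y_obs &lt; y_star])      # density of good configs
g = gaussian_kde(x_obs[y_obs &gt;= y_star])     # density of bad configs

candidates = rng.uniform(0, 1, 200)
ratio = l(candidates) / (g(candidates) + 1e-12)
print(candidates[np.argmax(ratio)])          # lands near the optimum, 0.3
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;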

&lt;p&gt;&lt;strong&gt;When to use which:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;skopt (GP)&lt;/strong&gt; — smooth, low-dimensional spaces; small evaluation budgets&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Optuna (TPE)&lt;/strong&gt; — large search spaces; conditional parameters; early pruning of bad trials&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Hyperopt (TPE)&lt;/strong&gt; — similar to Optuna, but older API; still widely used&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://jmlr.org/papers/v13/bergstra12a.html" rel="noopener noreferrer"&gt;Bergstra &amp;amp; Bengio (2012)&lt;/a&gt; — &lt;em&gt;Random Search for Hyper-Parameter Optimization&lt;/em&gt; — the paper that made random search the default&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://papers.nips.cc/paper/2012/hash/05311655a15b75fab86956663e1819cd-Abstract.html" rel="noopener noreferrer"&gt;Snoek et al. (2012)&lt;/a&gt; — &lt;em&gt;Practical Bayesian Optimization of Machine Learning Algorithms&lt;/em&gt; — the Spearmint paper&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://link.springer.com/article/10.1023/A:1008306431147" rel="noopener noreferrer"&gt;Jones et al. (1998)&lt;/a&gt; — &lt;em&gt;Efficient Global Optimization of Expensive Black-Box Functions&lt;/em&gt; — the EGO algorithm&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://arxiv.org/abs/1907.10902" rel="noopener noreferrer"&gt;Akiba et al. (2019)&lt;/a&gt; — &lt;em&gt;Optuna: A Next-generation Hyperparameter Optimization Framework&lt;/em&gt; — the most popular modern framework&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://ieeexplore.ieee.org/document/7352306" rel="noopener noreferrer"&gt;Shahriari et al. (2016)&lt;/a&gt; — &lt;em&gt;Taking the Human Out of the Loop: A Review of Bayesian Optimization&lt;/em&gt; — comprehensive survey&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/regression-playground" rel="noopener noreferrer"&gt;Regression Playground&lt;/a&gt; — Experiment with model complexity and see how different hyperparameters affect the fit&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/overfitting-explorer" rel="noopener noreferrer"&gt;Overfitting Explorer&lt;/a&gt; — Visualise the bias-variance tradeoff that hyperparameter tuning aims to optimise&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/genetic-algorithms-from-scratch" rel="noopener noreferrer"&gt;Genetic Algorithms from Scratch&lt;/a&gt; — Another gradient-free optimizer for black-box functions, using evolutionary strategies instead of surrogate models&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt; — The Gaussian Process generalises Bayesian updating from parameters to entire functions&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Island Hopping&lt;/a&gt; — Exploration vs exploitation in sampling — the same tension that drives acquisition functions&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;MLE from Scratch&lt;/a&gt; — Cross-validated accuracy is a likelihood proxy: we're maximising the probability that the model explains held-out data&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/optimisation/hyperparameter_optimization.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;More iterations&lt;/strong&gt; — Increase &lt;code&gt;n_calls&lt;/code&gt; to 50 for the Bayesian optimizer. Does the extra budget find meaningfully better configurations, or does it plateau?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Higher dimensions&lt;/strong&gt; — Add &lt;code&gt;max_depth&lt;/code&gt; (Integer, 5–50) and &lt;code&gt;min_samples_split&lt;/code&gt; (Integer, 2–20) as hyperparameters. How do the methods scale with 6 dimensions?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Optuna comparison&lt;/strong&gt; — Run the Optuna TPE sampler alongside GP-based optimization. Compare convergence curves. Does TPE find good configurations faster?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Acquisition function sweep&lt;/strong&gt; — Try &lt;code&gt;acq_func='LCB'&lt;/code&gt; and &lt;code&gt;acq_func='PI'&lt;/code&gt; in &lt;code&gt;gp_minimize&lt;/code&gt;. How does the exploration–exploitation balance change?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Simulated expensive evaluations&lt;/strong&gt; — Add &lt;code&gt;time.sleep(2)&lt;/code&gt; inside &lt;code&gt;evaluate_params&lt;/code&gt; so each call takes real time. Now the wall-clock difference between 15 evaluations (Bayesian) and 72 (grid) becomes tangible.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The three methods represent a progression in sophistication: grid search makes no assumptions, random search exploits the low effective dimensionality of most problems, and Bayesian optimization builds a model to make each evaluation count. For cheap models like Random Forests, random search is usually sufficient. But when each evaluation costs real time — training a Transformer, running a physics simulation, querying a paid API — the model-based approach of Bayesian optimization can save days of compute. The cost of fitting a GP is negligible compared to hours of training, and even 5 guided evaluations out of 15 total can find configurations that random search might need 50 iterations to reach.&lt;/p&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is hyperparameter optimisation and why does it matter?
&lt;/h3&gt;

&lt;p&gt;Hyperparameters are settings you choose before training a model (learning rate, tree depth, regularisation strength) that cannot be learned from the data. Poor choices lead to underfitting or overfitting. Systematic optimisation finds the best combination, often improving model performance significantly compared to manual tuning.&lt;/p&gt;

&lt;h3&gt;
  
  
  When should I use random search instead of grid search?
&lt;/h3&gt;

&lt;p&gt;Almost always. Random search is more efficient because it explores more unique values of each hyperparameter. Grid search wastes evaluations on unimportant parameter combinations, especially when only one or two hyperparameters actually matter. Random search achieves the same or better results with fewer evaluations in most practical settings.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is Bayesian optimisation and when is it worth the overhead?
&lt;/h3&gt;

&lt;p&gt;Bayesian optimisation builds a probabilistic model (typically a Gaussian process) of the objective function and uses it to choose the next hyperparameters to evaluate intelligently. It is worth the overhead when each evaluation is expensive (training takes hours) and the search space has fewer than about 20 dimensions. For cheap evaluations, random search is often sufficient.&lt;/p&gt;

&lt;h3&gt;
  
  
  How many hyperparameter evaluations should I run?
&lt;/h3&gt;

&lt;p&gt;For random search, 60 evaluations is usually enough to find a good configuration if only 1-3 hyperparameters matter (there is a 95% chance of sampling within the top 5% of the space). For Bayesian optimisation, 20-50 evaluations often suffice due to the intelligent search strategy. Scale up for higher-dimensional spaces.&lt;/p&gt;
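
&lt;p&gt;The 60-evaluation rule of thumb comes from a one-line calculation:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;P(at least one of n samples lands in the top 5%) = 1 - 0.95^n
1 - 0.95^n &gt;= 0.95  ⇒  0.95^n &lt;= 0.05  ⇒  n &gt;= log(0.05)/log(0.95) ≈ 58.4
n = 60 comfortably clears the bar: 1 - 0.95^60 ≈ 0.954
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;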

&lt;h3&gt;
  
  
  Should I use cross-validation during hyperparameter optimisation?
&lt;/h3&gt;

&lt;p&gt;Yes. Evaluating hyperparameters on a single train-test split is noisy and can lead to overfitting the validation set. K-fold cross-validation gives more reliable performance estimates. Use 5-fold as a default, or 3-fold if training is expensive. Always keep a final held-out test set that is never used during optimisation.&lt;/p&gt;
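
&lt;p&gt;A minimal sketch of that protocol in scikit-learn (a synthetic dataset stands in for your own):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_val_score

X, y = make_classification(n_samples=500, random_state=42)

# Carve off the final test set first; tuning never sees it
X_tune, X_test, y_tune, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)

# Every hyperparameter evaluation uses 5-fold CV on the tuning split only
model = RandomForestClassifier(n_estimators=200, random_state=42)
print(cross_val_score(model, X_tune, y_tune, cv=5).mean())

# The held-out test set is scored exactly once, at the very end
print(model.fit(X_tune, y_tune).score(X_test, y_test))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;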

</description>
      <category>optimisation</category>
      <category>supervisedlearning</category>
      <category>bayesian</category>
    </item>
    <item>
      <title>Policy Gradients: REINFORCE from Scratch with NumPy</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Wed, 08 Apr 2026 07:54:50 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/policy-gradients-reinforce-from-scratch-with-numpy-4e6j</link>
      <guid>https://dev.to/berkan_sesen/policy-gradients-reinforce-from-scratch-with-numpy-4e6j</guid>
      <description>&lt;p&gt;In the &lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;DQN post&lt;/a&gt;, we trained a neural network to estimate Q-values and then picked the best action with argmax. That works when the action space is discrete — push left or push right. But what if you need to control a robotic arm with continuous joint angles, or steer a car with a continuous throttle? You can't argmax over infinity.&lt;/p&gt;

&lt;p&gt;Policy gradient methods flip the approach: instead of learning a value function and deriving a policy, we &lt;strong&gt;directly parameterise the policy&lt;/strong&gt; and optimise it via gradient ascent. The network outputs action probabilities, we sample from them, and we nudge the parameters toward actions that led to high rewards. No Q-values, no argmax, no experience replay — just a policy, a gradient, and a reward signal.&lt;/p&gt;
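
&lt;p&gt;Formally, this is gradient ascent on the expected return: the policy gradient theorem gives &lt;code&gt;$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\big[\sum_t \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, G_t\big]$&lt;/code&gt;, where &lt;code&gt;$G_t$&lt;/code&gt; is the discounted return from step &lt;code&gt;$t$&lt;/code&gt;. Each action's log-probability is pushed up in proportion to how well the episode went afterwards.&lt;/p&gt;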

&lt;p&gt;By the end of this post, you'll implement the REINFORCE algorithm entirely from scratch in NumPy — including the forward pass, backpropagation, and RMSProp optimiser — and train it to balance CartPole. The entire implementation is about 100 lines. No PyTorch, no TensorFlow, just NumPy and the chain rule.&lt;/p&gt;

&lt;h2&gt;
  
  
  Quick Win: Run the Algorithm
&lt;/h2&gt;

&lt;p&gt;Let's see REINFORCE in action. Click the badge to open the interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/policy_gradient_cartpole.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzmvsd6kyyu3fqxtoret6.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzmvsd6kyyu3fqxtoret6.gif" alt="REINFORCE learning to balance CartPole — the 100-episode rolling average climbs from ~25 to ~490, converging within about 3,000 episodes" width="800" height="400"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The animation shows the agent converging to the maximum score of 500 within about 3,000 episodes, using nothing but NumPy. No deep learning framework, no replay buffer — just a policy, a gradient, and RMSProp.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;gymnasium&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;

&lt;span class="c1"&gt;# --- Hyperparameters ---
# Original code (CartPole-v0, max 200 steps): H=100, lr=1e-4, gamma=0.95, batch_size=5
# Adapted for CartPole-v1 (max 500 steps): higher gamma and learning rate
&lt;/span&gt;&lt;span class="n"&gt;H&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;              &lt;span class="c1"&gt;# hidden layer neurons
&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;       &lt;span class="c1"&gt;# episodes per parameter update
&lt;/span&gt;&lt;span class="n"&gt;learning_rate&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1e-3&lt;/span&gt; &lt;span class="c1"&gt;# RMSProp learning rate (original: 1e-4, raised for longer episodes)
&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.99&lt;/span&gt;         &lt;span class="c1"&gt;# discount factor (original: 0.95, raised for longer horizon)
&lt;/span&gt;&lt;span class="n"&gt;decay_rate&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.99&lt;/span&gt;    &lt;span class="c1"&gt;# RMSProp decay
&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1e-5&lt;/span&gt;       &lt;span class="c1"&gt;# RMSProp epsilon
&lt;/span&gt;
&lt;span class="c1"&gt;# --- Network functions ---
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;sigmoid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;W1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;  &lt;span class="c1"&gt;# ReLU
&lt;/span&gt;    &lt;span class="n"&gt;logp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;W2&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;sigmoid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logp&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eph&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epdlogp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;dW2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eph&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epdlogp&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;ravel&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;dh&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;outer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epdlogp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;W2&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="n"&gt;dh&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;eph&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;  &lt;span class="c1"&gt;# backprop through ReLU
&lt;/span&gt;    &lt;span class="n"&gt;dW1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dh&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epx&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;W1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;dW1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;W2&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;dW2&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;discount_rewards&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Standard full-horizon discounting.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;discounted_r&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;running_add&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;reversed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;)):&lt;/span&gt;
        &lt;span class="n"&gt;running_add&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;running_add&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;discounted_r&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;running_add&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;discounted_r&lt;/span&gt;

&lt;span class="c1"&gt;# --- Initialisation ---
&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CartPole-v1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;observation&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;D&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;observation&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# input dimension (4 for CartPole)
&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;W1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;H&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;D&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;D&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;  &lt;span class="c1"&gt;# Xavier init
&lt;/span&gt;    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;W2&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;H&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;H&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;span class="p"&gt;}&lt;/span&gt;
&lt;span class="n"&gt;grad_buffer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()}&lt;/span&gt;
&lt;span class="n"&gt;rmsprop_cache&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()}&lt;/span&gt;

&lt;span class="c1"&gt;# --- Training loop ---
&lt;/span&gt;&lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dlogps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;drs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;episode_durations&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="n"&gt;episode_number&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

&lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="n"&gt;episode_number&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="mi"&gt;5000&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;observation&lt;/span&gt;
    &lt;span class="n"&gt;aprob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;uniform&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;aprob&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

    &lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;hs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;dlogps&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;aprob&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# policy gradient
&lt;/span&gt;
    &lt;span class="n"&gt;observation&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;drs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;episode_number&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
        &lt;span class="n"&gt;episode_durations&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

        &lt;span class="n"&gt;epx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;vstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;eph&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;vstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;epdlogp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;vstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dlogps&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;epr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;vstack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;drs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dlogps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;drs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

        &lt;span class="c1"&gt;# Discount and standardise rewards
&lt;/span&gt;        &lt;span class="n"&gt;discounted_epr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;discount_rewards&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epr&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;discounted_epr&lt;/span&gt; &lt;span class="o"&gt;-=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;discounted_epr&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;discounted_epr&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;std&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;discounted_epr&lt;/span&gt; &lt;span class="o"&gt;/=&lt;/span&gt; &lt;span class="n"&gt;std&lt;/span&gt;

        &lt;span class="c1"&gt;# The PG magic: weight gradients by advantage
&lt;/span&gt;        &lt;span class="n"&gt;epdlogp&lt;/span&gt; &lt;span class="o"&gt;*=&lt;/span&gt; &lt;span class="n"&gt;discounted_epr&lt;/span&gt;
        &lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eph&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epdlogp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;grad_buffer&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;grad&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

        &lt;span class="c1"&gt;# RMSProp update every batch_size episodes
&lt;/span&gt;        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;episode_number&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
                &lt;span class="n"&gt;g&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;grad_buffer&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
                &lt;span class="n"&gt;rmsprop_cache&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decay_rate&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;rmsprop_cache&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;decay_rate&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;
                &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;learning_rate&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;rmsprop_cache&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
                &lt;span class="n"&gt;grad_buffer&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;episode_number&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="mi"&gt;500&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;avg&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;episode_durations&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;:])&lt;/span&gt;
            &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Episode &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;episode_number&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;, 100-ep avg: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;avg&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;observation&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;close&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Final 100-episode average: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;episode_durations&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;The result:&lt;/strong&gt; The agent converges to the 500-step maximum, using nothing but NumPy. No deep learning framework — we compute every gradient by hand.&lt;/p&gt;

&lt;h3&gt;
  
  
  Visualise the Learning
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;

&lt;span class="n"&gt;rolling&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;convolve&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;episode_durations&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mode&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;valid&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;rolling&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;b-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axhline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;500&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;g&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Max score (500)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Episode&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Duration (100-episode rolling avg)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;REINFORCE on CartPole-v1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;550&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fwve0mkiiz80rgqh8bv91.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fwve0mkiiz80rgqh8bv91.webp" alt="REINFORCE on CartPole-v1 — 100-episode rolling average reward climbing from ~25 to ~490" width="800" height="449"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;We built a complete RL agent with no frameworks — just a policy network, a gradient, and a reward signal. Let's walk through each piece.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Policy Network (4 → 100 → 1)
&lt;/h3&gt;

&lt;p&gt;The network is a two-layer perceptron that maps a 4-dimensional CartPole state to a single probability:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;State [x, ẋ, θ, θ̇]  →  Hidden (100 ReLU)  →  Output (sigmoid)  →  P(push right)
         4                    100                     1
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
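
&lt;p&gt;The &lt;code&gt;forward&lt;/code&gt; pass below assumes a &lt;code&gt;sigmoid&lt;/code&gt; helper and a &lt;code&gt;model&lt;/code&gt; dict holding the weights. Here is a minimal setup sketch (the Xavier-style scaling is our assumption, not necessarily the notebook's exact initialisation):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

def sigmoid(z):
    # squash a real-valued logit to a probability in (0, 1)
    return 1.0 / (1.0 + np.exp(-z))

H, D = 100, 4  # hidden units, input dimension
model = {
    'W1': np.random.randn(H, D) / np.sqrt(D),  # (100, 4), Xavier-style scaling (assumption)
    'W2': np.random.randn(H) / np.sqrt(H),     # (100,)
}
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;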





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;W1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;   &lt;span class="c1"&gt;# (100, 4) × (4,) → (100,)
&lt;/span&gt;    &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;                  &lt;span class="c1"&gt;# ReLU activation
&lt;/span&gt;    &lt;span class="n"&gt;logp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;W2&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (100,) × (100,) → scalar
&lt;/span&gt;    &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;sigmoid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logp&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;             &lt;span class="c1"&gt;# squash to [0, 1]
&lt;/span&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The output &lt;code&gt;p&lt;/code&gt; is the probability of pushing right. We sample from this Bernoulli distribution:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;uniform&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;aprob&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is fundamentally different from DQN's approach. DQN outputs Q-values for every action and picks the highest: a deterministic argmax, with exploration bolted on via an epsilon-greedy schedule. Here, the network outputs a &lt;strong&gt;probability&lt;/strong&gt; and we &lt;strong&gt;sample&lt;/strong&gt; — the policy is stochastic by design. This built-in exploration means we don't need an epsilon schedule.&lt;/p&gt;
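
&lt;p&gt;For contrast, a value-based agent would act greedily over Q-values. The snippet below is a hypothetical sketch (&lt;code&gt;q_values&lt;/code&gt;, &lt;code&gt;epsilon&lt;/code&gt; and &lt;code&gt;num_actions&lt;/code&gt; are made up for illustration), not code from this post:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

q_values = np.array([0.9, 1.3])            # one (made-up) Q-value per action
epsilon, num_actions = 0.1, 2
action = int(np.argmax(q_values))          # deterministic argmax
if np.random.uniform() &amp;lt; epsilon:           # exploration has to be bolted on
    action = np.random.randint(num_actions)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;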

&lt;h3&gt;
  
  
  The Policy Gradient Signal
&lt;/h3&gt;

&lt;p&gt;After taking an action, we record the "gradient that encourages the action that was taken":&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;dlogps&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;aprob&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;If we pushed right (&lt;code&gt;action=1&lt;/code&gt;) and &lt;code&gt;aprob=0.7&lt;/code&gt;, then &lt;code&gt;dlogp = 1 - 0.7 = 0.3&lt;/code&gt; — a positive gradient that nudges the network to make "push right" even more likely. If we pushed left (&lt;code&gt;action=0&lt;/code&gt;) and &lt;code&gt;aprob=0.7&lt;/code&gt;, then &lt;code&gt;dlogp = 0 - 0.7 = -0.7&lt;/code&gt; — a negative gradient that decreases the probability of pushing right (making left more likely).&lt;/p&gt;

&lt;p&gt;But at this point, we don't know if the action was &lt;em&gt;good&lt;/em&gt;. That's what the reward tells us.&lt;/p&gt;

&lt;h3&gt;
  
  
  The PG Magic Line
&lt;/h3&gt;

&lt;p&gt;This is where policy gradients really happen:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;epdlogp&lt;/span&gt; &lt;span class="o"&gt;*=&lt;/span&gt; &lt;span class="n"&gt;discounted_epr&lt;/span&gt;  &lt;span class="c1"&gt;# modulate gradient with advantage
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Every action's gradient gets multiplied by its discounted return (standardised, as described below). Actions with above-average returns get their gradients amplified: "do more of this". Actions with below-average returns, which are negative after standardisation, get their gradients flipped: "do less of that". This single line is the heart of REINFORCE.&lt;/p&gt;
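
&lt;p&gt;A toy example (the numbers are invented) makes the sign flip concrete:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

epdlogp    = np.array([0.3, -0.7, 0.3])  # score-function grads for three steps
advantages = np.array([1.2, 0.5, -0.8])  # standardised returns (made up)
print(epdlogp * advantages)  # [ 0.36 -0.35 -0.24]: the third grad is flipped by its negative return
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;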

&lt;p&gt;Think of it like a coach watching game film: every play gets a grade. Plays that led to scoring get reinforced; plays that led to turnovers get discouraged. The magnitude of the grade determines how strongly the feedback applies.&lt;/p&gt;

&lt;h3&gt;
  
  
  Reward Discounting
&lt;/h3&gt;

&lt;p&gt;CartPole gives +1 reward at every timestep the pole stays up. We compute each action's discounted return-to-go (the discounted sum of all rewards that followed it), so actions early in a long run accumulate more credit than actions taken just before a fall:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;discount_rewards&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;discounted_r&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros_like&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;running_add&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;reversed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="p"&gt;)):&lt;/span&gt;
        &lt;span class="n"&gt;running_add&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;running_add&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;discounted_r&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;running_add&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;discounted_r&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
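
&lt;p&gt;A quick sanity check on a toy three-step episode (&lt;code&gt;gamma=0.9&lt;/code&gt; chosen only so the numbers stay readable):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

r = np.array([1.0, 1.0, 1.0])          # +1 per step, as in CartPole
print(discount_rewards(r, gamma=0.9))  # [2.71 1.9  1.  ]: earlier steps accumulate more credit
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;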



&lt;p&gt;With &lt;code&gt;gamma=0.99&lt;/code&gt;, an action taken 100 steps before the end still receives &lt;code&gt;$0.99^{100} \approx 0.37$&lt;/code&gt; of the terminal reward. This gives the network a smooth gradient across the episode: early actions in long episodes receive higher discounted returns than the same actions in short episodes, so the network learns to favour strategies that keep the pole up for longer.&lt;/p&gt;

&lt;h3&gt;
  
  
  Reward Standardisation (Variance Reduction)
&lt;/h3&gt;

&lt;p&gt;After discounting, we standardise the rewards to have zero mean and unit variance:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;discounted_epr&lt;/span&gt; &lt;span class="o"&gt;-=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;discounted_epr&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;discounted_epr&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;std&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;discounted_epr&lt;/span&gt; &lt;span class="o"&gt;/=&lt;/span&gt; &lt;span class="n"&gt;std&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is critical for stable training. Without standardisation, the magnitude of the gradient depends on the absolute reward scale. With it, roughly half the actions get reinforced (above-average) and half get discouraged (below-average). It's a simple form of a &lt;strong&gt;baseline&lt;/strong&gt; — one of the most important variance reduction techniques in policy gradients.&lt;/p&gt;
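
&lt;p&gt;Continuing the toy episode from &lt;code&gt;discount_rewards&lt;/code&gt; above, standardisation pushes the late steps below zero:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

discounted_epr = np.array([2.71, 1.9, 1.0])  # toy returns from the previous sketch
discounted_epr -= np.mean(discounted_epr)
discounted_epr /= np.std(discounted_epr)     # std is nonzero here, so no guard needed
print(discounted_epr)  # ≈ [ 1.2   0.04 -1.25]: the last step is now discouraged
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;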

&lt;h3&gt;
  
  
  Manual Backpropagation
&lt;/h3&gt;

&lt;p&gt;We compute gradients by hand, just as in the &lt;a href="https://sesen.ai/blog/backpropagation-neural-nets-from-first-principles" rel="noopener noreferrer"&gt;backpropagation post&lt;/a&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eph&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epdlogp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;dW2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eph&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epdlogp&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;ravel&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;     &lt;span class="c1"&gt;# gradient for output weights
&lt;/span&gt;    &lt;span class="n"&gt;dh&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;outer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epdlogp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;W2&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;        &lt;span class="c1"&gt;# backprop to hidden layer
&lt;/span&gt;    &lt;span class="n"&gt;dh&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;eph&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;                           &lt;span class="c1"&gt;# backprop through ReLU
&lt;/span&gt;    &lt;span class="n"&gt;dW1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dh&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epx&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;                   &lt;span class="c1"&gt;# gradient for input weights
&lt;/span&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;W1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;dW1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;W2&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;dW2&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The chain rule flows backwards: output layer gradient → hidden layer gradient (masked by ReLU) → input layer gradient. This is identical to standard backprop, except the "loss" is the policy gradient signal (&lt;code&gt;epdlogp&lt;/code&gt; already weighted by advantage).&lt;/p&gt;
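
&lt;p&gt;A shape sanity check, using a hypothetical three-step episode together with the &lt;code&gt;model&lt;/code&gt; dict sketched earlier:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

eph     = np.random.rand(3, 100)   # hidden activations, one row per step
epdlogp = np.random.randn(3)       # advantage-weighted score gradients
epx     = np.random.rand(3, 4)     # states, one row per step
grad = backward(eph, epdlogp, epx, model)
print(grad['W1'].shape, grad['W2'].shape)  # (100, 4) (100,)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;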

&lt;h3&gt;
  
  
  Batch Gradient Accumulation
&lt;/h3&gt;

&lt;p&gt;Rather than updating after every episode, we accumulate gradients over &lt;code&gt;batch_size=5&lt;/code&gt; episodes:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;grad_buffer&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;grad&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;episode_number&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="c1"&gt;# RMSProp update using accumulated gradients
&lt;/span&gt;    &lt;span class="bp"&gt;...&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This reduces gradient variance — each update reflects 5 episodes worth of experience rather than just one. It's the same idea as minibatch gradient descent in supervised learning.&lt;/p&gt;

&lt;h3&gt;
  
  
  RMSProp from Scratch
&lt;/h3&gt;

&lt;p&gt;The optimiser is RMSProp, implemented manually:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;rmsprop_cache&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;decay_rate&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;rmsprop_cache&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;decay_rate&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;learning_rate&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;g&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;rmsprop_cache&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;RMSProp maintains a running average of squared gradients and divides by their root. This gives &lt;strong&gt;adaptive per-parameter learning rates&lt;/strong&gt;: parameters with consistently large gradients get smaller effective learning rates, and vice versa. The &lt;code&gt;decay_rate=0.99&lt;/code&gt; means the running average looks back roughly 100 updates.&lt;/p&gt;
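
&lt;p&gt;Putting batch accumulation and RMSProp together, the whole update step might look like the sketch below. The buffer reset at the end is our assumption about what the elided &lt;code&gt;...&lt;/code&gt; above does; the variable names follow the post:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;if episode_number % batch_size == 0:
    for k, v in model.items():
        g = grad_buffer[k]  # gradient summed over the last batch_size episodes
        rmsprop_cache[k] = decay_rate * rmsprop_cache[k] + (1 - decay_rate) * g**2
        model[k] += learning_rate * g / (np.sqrt(rmsprop_cache[k]) + epsilon)
        grad_buffer[k] = np.zeros_like(v)  # reset the accumulator for the next batch (assumption)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;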

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Policy Gradient Theorem
&lt;/h3&gt;

&lt;p&gt;The goal of policy gradients is to maximise expected cumulative reward:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DJ%28%255Ctheta%29%2520%253D%2520%255Cmathbb%257BE%257D_%257B%255Ctau%2520%255Csim%2520%255Cpi_%255Ctheta%257D%255Cleft%255B%255Csum_%257Bt%253D0%257D%255E%257BT%257D%2520%255Cgamma%255Et%2520r_t%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DJ%28%255Ctheta%29%2520%253D%2520%255Cmathbb%257BE%257D_%257B%255Ctau%2520%255Csim%2520%255Cpi_%255Ctheta%257D%255Cleft%255B%255Csum_%257Bt%253D0%257D%255E%257BT%257D%2520%255Cgamma%255Et%2520r_t%255Cright%255D" alt="equation" width="246" height="75"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Where &lt;code&gt;$\tau$&lt;/code&gt; is a trajectory (sequence of states and actions) sampled under policy &lt;code&gt;$\pi_\theta$&lt;/code&gt;. The policy gradient theorem (&lt;a href="https://papers.nips.cc/paper/1999/hash/464d828b85b0bed98e80ade0a5c43b0f-Abstract.html" rel="noopener noreferrer"&gt;Sutton et al., 1999&lt;/a&gt;) tells us how to compute the gradient of this objective:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cnabla_%255Ctheta%2520J%28%255Ctheta%29%2520%253D%2520%255Cmathbb%257BE%257D_%257B%255Ctau%2520%255Csim%2520%255Cpi_%255Ctheta%257D%255Cleft%255B%255Csum_%257Bt%253D0%257D%255E%257BT%257D%2520%255Cnabla_%255Ctheta%2520%255Clog%2520%255Cpi_%255Ctheta%28a_t%2520%257C%2520s_t%29%2520%255Ccdot%2520G_t%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cnabla_%255Ctheta%2520J%28%255Ctheta%29%2520%253D%2520%255Cmathbb%257BE%257D_%257B%255Ctau%2520%255Csim%2520%255Cpi_%255Ctheta%257D%255Cleft%255B%255Csum_%257Bt%253D0%257D%255E%257BT%257D%2520%255Cnabla_%255Ctheta%2520%255Clog%2520%255Cpi_%255Ctheta%28a_t%2520%257C%2520s_t%29%2520%255Ccdot%2520G_t%255Cright%255D" alt="equation" width="440" height="75"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Where &lt;code&gt;$G_t = \sum_{k=t}^{T} \gamma^{k-t} r_k$&lt;/code&gt; is the return from timestep &lt;code&gt;$t$&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;The intuition: &lt;code&gt;$\nabla_\theta \log \pi_\theta(a_t | s_t)$&lt;/code&gt; points in the direction that makes action &lt;code&gt;$a_t$&lt;/code&gt; more likely. Multiplying by &lt;code&gt;$G_t$&lt;/code&gt; scales this — if the return was high, push hard in that direction; if low, push the other way.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Log-Likelihood Trick
&lt;/h3&gt;

&lt;p&gt;How do we differentiate through sampling? We can't backpropagate through &lt;code&gt;np.random.uniform()&lt;/code&gt;. The &lt;strong&gt;score function estimator&lt;/strong&gt; (also called the log-likelihood trick or REINFORCE trick) sidesteps this.&lt;/p&gt;

&lt;p&gt;For a Bernoulli policy with probability &lt;code&gt;$p$&lt;/code&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cnabla_%255Ctheta%2520%255Clog%2520%255Cpi_%255Ctheta%28a%2520%257C%2520s%29%2520%253D%2520%255Cfrac%257Ba%2520-%2520p%257D%257Bp%281-p%29%257D%2520%255Ccdot%2520%255Cnabla_%255Ctheta%2520p" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cnabla_%255Ctheta%2520%255Clog%2520%255Cpi_%255Ctheta%28a%2520%257C%2520s%29%2520%253D%2520%255Cfrac%257Ba%2520-%2520p%257D%257Bp%281-p%29%257D%2520%255Ccdot%2520%255Cnabla_%255Ctheta%2520p" alt="equation" width="328" height="51"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;But in practice a simpler form falls out. Because &lt;code&gt;$p$&lt;/code&gt; comes from a sigmoid, differentiating the log-likelihood with respect to the pre-sigmoid logit cancels the &lt;code&gt;$p(1-p)$&lt;/code&gt; factor (the sigmoid's own derivative), leaving just &lt;code&gt;$a - p$&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;dlogps&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;aprob&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# this IS the score function
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is &lt;code&gt;action - aprob&lt;/code&gt; — exactly what the code computes. When &lt;code&gt;action=1&lt;/code&gt; and &lt;code&gt;aprob=0.7&lt;/code&gt;, the gradient is &lt;code&gt;0.3&lt;/code&gt;: "make right more likely." This gradient gets backpropagated through the network to update all weights.&lt;/p&gt;
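
&lt;p&gt;To see the cancellation explicitly, let &lt;code&gt;$z$&lt;/code&gt; be the pre-sigmoid logit, so &lt;code&gt;$p = \sigma(z)$&lt;/code&gt; and &lt;code&gt;$\sigma'(z) = p(1-p)$&lt;/code&gt;. Then &lt;code&gt;$\frac{\partial}{\partial z} \log \pi_\theta(a \mid s) = \frac{a - p}{p(1-p)} \cdot p(1-p) = a - p$&lt;/code&gt;.&lt;/p&gt;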

&lt;h3&gt;
  
  
  Why Reward Standardisation Reduces Variance
&lt;/h3&gt;

&lt;p&gt;REINFORCE is an &lt;strong&gt;unbiased&lt;/strong&gt; estimator of the policy gradient, but it has &lt;strong&gt;high variance&lt;/strong&gt;. A single trajectory might get lucky or unlucky, leading to noisy gradient estimates.&lt;/p&gt;

&lt;p&gt;Subtracting a baseline &lt;code&gt;$b$&lt;/code&gt; from the return doesn't change the expected gradient (it's still unbiased) but can dramatically reduce variance:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cnabla_%255Ctheta%2520J%28%255Ctheta%29%2520%253D%2520%255Cmathbb%257BE%257D%255Cleft%255B%255Cnabla_%255Ctheta%2520%255Clog%2520%255Cpi_%255Ctheta%28a_t%2520%257C%2520s_t%29%2520%255Ccdot%2520%28G_t%2520-%2520b%29%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cnabla_%255Ctheta%2520J%28%255Ctheta%29%2520%253D%2520%255Cmathbb%257BE%257D%255Cleft%255B%255Cnabla_%255Ctheta%2520%255Clog%2520%255Cpi_%255Ctheta%28a_t%2520%257C%2520s_t%29%2520%255Ccdot%2520%28G_t%2520-%2520b%29%255Cright%255D" alt="equation" width="403" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Our code uses the &lt;strong&gt;mean reward&lt;/strong&gt; as the baseline:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;discounted_epr&lt;/span&gt; &lt;span class="o"&gt;-=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;discounted_epr&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;With this baseline, actions that performed better than average get reinforced, and below-average actions get discouraged. Without it, if all rewards are positive (as in CartPole, where every alive step gives +1), the gradient would reinforce &lt;em&gt;all&lt;/em&gt; actions — just some more than others. The baseline makes the signal much cleaner.&lt;/p&gt;

&lt;h3&gt;
  
  
  Hyperparameters
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Parameter&lt;/th&gt;
&lt;th&gt;What it controls&lt;/th&gt;
&lt;th&gt;Value&lt;/th&gt;
&lt;th&gt;Original (v0)&lt;/th&gt;
&lt;th&gt;Why changed&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;H&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Hidden neurons&lt;/td&gt;
&lt;td&gt;100&lt;/td&gt;
&lt;td&gt;100&lt;/td&gt;
&lt;td&gt;Unchanged — enough for CartPole&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;learning_rate&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;RMSProp step size&lt;/td&gt;
&lt;td&gt;1e-3&lt;/td&gt;
&lt;td&gt;1e-4&lt;/td&gt;
&lt;td&gt;Longer episodes → more samples per update → can use larger steps&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;gamma&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Discount factor&lt;/td&gt;
&lt;td&gt;0.99&lt;/td&gt;
&lt;td&gt;0.95&lt;/td&gt;
&lt;td&gt;Longer episodes (500 vs 200) need longer-horizon credit assignment&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;batch_size&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Episodes per update&lt;/td&gt;
&lt;td&gt;5&lt;/td&gt;
&lt;td&gt;5&lt;/td&gt;
&lt;td&gt;Unchanged&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;decay_rate&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;RMSProp memory&lt;/td&gt;
&lt;td&gt;0.99&lt;/td&gt;
&lt;td&gt;0.99&lt;/td&gt;
&lt;td&gt;Unchanged&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;epsilon&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;RMSProp stability&lt;/td&gt;
&lt;td&gt;1e-5&lt;/td&gt;
&lt;td&gt;1e-5&lt;/td&gt;
&lt;td&gt;Unchanged&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;The original code targeted CartPole-v0 (max 200 steps) and used a custom "blame window" that penalised only the last N=10 actions. This worked well for short episodes but created a performance ceiling on CartPole-v1 (max 500 steps) — as episodes get longer, the blame signal becomes proportionally smaller and gets washed out after standardisation.&lt;/p&gt;

&lt;p&gt;The fix is textbook REINFORCE: standard full-horizon discounting with &lt;code&gt;gamma=0.99&lt;/code&gt;. The higher gamma ensures actions 100+ steps before the terminal state still receive meaningful credit. The higher learning rate (1e-3) compensates for the longer episodes providing more gradient samples per update.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Learning rate is critical.&lt;/strong&gt; Too high and the policy oscillates wildly. Too low and learning crawls. With standard discounting and &lt;code&gt;gamma=0.99&lt;/code&gt;, &lt;code&gt;lr=1e-3&lt;/code&gt; converges in ~3,000 episodes.&lt;/p&gt;

&lt;h3&gt;
  
  
  Value-Based vs Policy-Based: When to Use Each
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;&lt;/th&gt;
&lt;th&gt;DQN (Value-Based)&lt;/th&gt;
&lt;th&gt;REINFORCE (Policy-Based)&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Outputs&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Q-values for each action&lt;/td&gt;
&lt;td&gt;Action probabilities&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Action selection&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Argmax (deterministic)&lt;/td&gt;
&lt;td&gt;Sample (stochastic)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Exploration&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Epsilon-greedy (bolted on)&lt;/td&gt;
&lt;td&gt;Built-in (stochastic policy)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Action spaces&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Discrete only&lt;/td&gt;
&lt;td&gt;Discrete and continuous&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Sample efficiency&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Higher (replay buffer)&lt;/td&gt;
&lt;td&gt;Lower (on-policy, no replay)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Stability&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Needs replay + target net&lt;/td&gt;
&lt;td&gt;Needs variance reduction&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Convergence&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;To optimal Q → optimal policy&lt;/td&gt;
&lt;td&gt;Directly to (locally) optimal policy&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;Use policy gradients when:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;The action space is continuous (robotics, continuous control)&lt;/li&gt;
&lt;li&gt;You want a stochastic policy (exploration, multi-modal strategies)&lt;/li&gt;
&lt;li&gt;The policy is simpler than the value function&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;Use DQN when:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;The action space is small and discrete&lt;/li&gt;
&lt;li&gt;Sample efficiency matters (you can't run millions of episodes)&lt;/li&gt;
&lt;li&gt;You have access to a simulator for experience replay&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  When NOT to Use Vanilla REINFORCE
&lt;/h3&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;High-dimensional action spaces&lt;/strong&gt; — Variance becomes unmanageable. Use actor-critic methods (A2C, PPO) that learn a value function baseline&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Sample-scarce settings&lt;/strong&gt; — REINFORCE is on-policy: every trajectory is used once then discarded. Off-policy methods like DQN are far more sample-efficient&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Long episodes&lt;/strong&gt; — The credit assignment problem worsens. Which of the 10,000 actions caused success? Actor-critic methods handle this better with per-step value estimates&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;When stability matters&lt;/strong&gt; — Vanilla REINFORCE can have large policy swings. PPO clips the probability ratio in its surrogate objective to prevent catastrophic updates&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Deep Dive: The Papers
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Williams (1992) — The REINFORCE Paper
&lt;/h3&gt;

&lt;p&gt;The REINFORCE algorithm was introduced by Ronald J. Williams in his 1992 paper &lt;a href="https://link.springer.com/article/10.1007/BF00992696" rel="noopener noreferrer"&gt;&lt;em&gt;Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning&lt;/em&gt;&lt;/a&gt;. This is one of the foundational papers in policy gradient methods.&lt;/p&gt;

&lt;p&gt;Williams framed the problem precisely: given a stochastic network (he called it a "connectionist network") that produces actions probabilistically, how do we adjust the weights to maximise expected reward?&lt;/p&gt;

&lt;p&gt;His key result:&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"For each such algorithm, convergence to at least a local maximum in expected reinforcement is assured by a simple condition on the sequence of values used to scale the gradient estimate."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;The REINFORCE update rule from the paper:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255CDelta%2520w_%257Bij%257D%2520%253D%2520%255Calpha_%257Bij%257D%28r%2520-%2520b_%257Bij%257D%29%2520%255Cfrac%257B%255Cpartial%2520%255Cln%2520g_i%257D%257B%255Cpartial%2520w_%257Bij%257D%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255CDelta%2520w_%257Bij%257D%2520%253D%2520%255Calpha_%257Bij%257D%28r%2520-%2520b_%257Bij%257D%29%2520%255Cfrac%257B%255Cpartial%2520%255Cln%2520g_i%257D%257B%255Cpartial%2520w_%257Bij%257D%257D" alt="equation" width="266" height="61"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Where:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;$w_{ij}$&lt;/code&gt; — weight from unit &lt;code&gt;$j$&lt;/code&gt; to unit &lt;code&gt;$i$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$\alpha_{ij}$&lt;/code&gt; — learning rate&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$r$&lt;/code&gt; — the reinforcement (reward)&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$b_{ij}$&lt;/code&gt; — a reinforcement baseline (reduces variance without introducing bias)&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$g_i$&lt;/code&gt; — the probability of the output of unit &lt;code&gt;$i$&lt;/code&gt;
&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;In our code, this maps directly:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;$\frac{\partial \ln g_i}{\partial w_{ij}}$&lt;/code&gt; → &lt;code&gt;dlogps&lt;/code&gt; (the score function, computed via backprop)&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$r - b_{ij}$&lt;/code&gt; → &lt;code&gt;discounted_epr&lt;/code&gt; (standardised, which implicitly subtracts a mean baseline)&lt;/li&gt;
&lt;li&gt;The product &lt;code&gt;epdlogp *= discounted_epr&lt;/code&gt; is the REINFORCE update&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Williams proved that this estimator is &lt;strong&gt;unbiased&lt;/strong&gt;: in expectation, it equals the true policy gradient regardless of the baseline &lt;code&gt;$b$&lt;/code&gt;. The baseline only affects variance — choosing it well (e.g., as the mean reward) can dramatically speed convergence.&lt;/p&gt;
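
&lt;p&gt;A quick simulation (toy numbers, not from the paper) shows a baseline leaving the mean gradient untouched while shrinking its spread:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

rng = np.random.default_rng(0)
p = 0.6                                    # fixed Bernoulli policy probability
a = rng.random(100_000) &amp;lt; p                # sampled actions
G = np.where(a, 1.0, 0.2) + rng.normal(0, 0.1, a.size)  # made-up returns: 'right' pays more
score = a.astype(float) - p                # score function a - p

for b in (0.0, G.mean()):                  # no baseline vs mean baseline
    g = score * (G - b)
    print(f"baseline={b:.2f}  mean={g.mean():+.4f}  std={g.std():.4f}")
    # mean stays ≈ +0.19 in both cases; std drops roughly threefold with the baseline
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;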

&lt;h3&gt;
  
  
  Sutton, McAllester, Singh &amp;amp; Mansour (1999) — The Policy Gradient Theorem
&lt;/h3&gt;

&lt;p&gt;The policy gradient theorem, proved in &lt;a href="https://papers.nips.cc/paper/1999/hash/464d828b85b0bed98e80ade0a5c43b0f-Abstract.html" rel="noopener noreferrer"&gt;&lt;em&gt;Policy Gradient Methods for Reinforcement Learning with Function Approximation&lt;/em&gt;&lt;/a&gt;, generalised Williams' result to arbitrary function approximators:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cnabla_%255Ctheta%2520J%28%255Ctheta%29%2520%253D%2520%255Csum_s%2520d%255E%257B%255Cpi%257D%28s%29%2520%255Csum_a%2520%255Cnabla_%255Ctheta%2520%255Cpi_%255Ctheta%28a%257Cs%29%2520Q%255E%257B%255Cpi%257D%28s%252Ca%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cnabla_%255Ctheta%2520J%28%255Ctheta%29%2520%253D%2520%255Csum_s%2520d%255E%257B%255Cpi%257D%28s%29%2520%255Csum_a%2520%255Cnabla_%255Ctheta%2520%255Cpi_%255Ctheta%28a%257Cs%29%2520Q%255E%257B%255Cpi%257D%28s%252Ca%29" alt="equation" width="437" height="53"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Where &lt;code&gt;$d^{\pi}(s)$&lt;/code&gt; is the stationary state distribution under policy &lt;code&gt;$\pi$&lt;/code&gt;. The theorem shows that the policy gradient doesn't depend on the gradient of the state distribution — a surprising and essential result that makes policy gradient methods practical.&lt;/p&gt;

&lt;h3&gt;
  
  
  Karpathy (2016) — Deep RL: Pong from Pixels
&lt;/h3&gt;

&lt;p&gt;The original code was inspired by Andrej Karpathy's influential blog post &lt;a href="https://karpathy.github.io/2016/05/31/rl/" rel="noopener noreferrer"&gt;&lt;em&gt;Deep Reinforcement Learning: Pong from Pixels&lt;/em&gt;&lt;/a&gt;. Karpathy demonstrated that a simple two-layer network trained with REINFORCE could learn to play Atari Pong directly from raw pixels — all in ~130 lines of Python.&lt;/p&gt;

&lt;p&gt;The key architectural choices we inherit:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Two-layer network&lt;/strong&gt; with ReLU hidden layer and sigmoid output&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Manual NumPy backprop&lt;/strong&gt; — no framework dependency&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;RMSProp&lt;/strong&gt; as the optimiser&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Reward discounting&lt;/strong&gt; with standardisation&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Karpathy's version used the raw Pong pixel difference as input (preprocessed frames). Our CartPole version uses the 4-dimensional state vector directly, but the algorithm is identical.&lt;/p&gt;

&lt;h3&gt;
  
  
  The REINFORCE Algorithm (Pseudocode)
&lt;/h3&gt;

&lt;p&gt;From Sutton &amp;amp; Barto (2018), Section 13.3:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Initialize policy parameters θ arbitrarily
For each episode:
    Generate trajectory τ = (s₀, a₀, r₁, s₁, a₁, ..., sT) following π_θ
    For each step t = 0, 1, ..., T-1:
        Gₜ ← Σ_{k=t+1}^{T} γ^{k-t-1} rₖ       (return from step t)
        θ ← θ + α γ^t Gₜ ∇_θ ln π_θ(aₜ|sₜ)    (policy gradient update)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Our Implementation vs the Theory
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Theory (Williams, 1992)&lt;/th&gt;
&lt;th&gt;Our code&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;code&gt;$\frac{\partial \ln g_i}{\partial w_{ij}}$&lt;/code&gt; (score function)&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;action - aprob&lt;/code&gt; backpropagated through network&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;code&gt;$r - b_{ij}$&lt;/code&gt; (reward minus baseline)&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;discounted_epr&lt;/code&gt; after mean subtraction&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;$\Delta w = \alpha (r-b) \nabla \ln \pi$&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;epdlogp *= discounted_epr&lt;/code&gt; then &lt;code&gt;backward()&lt;/code&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Single-sample update&lt;/td&gt;
&lt;td&gt;Batch of 5 episodes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Generic optimiser&lt;/td&gt;
&lt;td&gt;RMSProp (adaptive learning rates)&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;h3&gt;
  
  
  What Came After
&lt;/h3&gt;

&lt;p&gt;REINFORCE spawned the entire family of modern policy gradient methods:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Actor-Critic&lt;/strong&gt; (&lt;a href="https://papers.nips.cc/paper/1999/hash/6449f44a102fde848669bdd9eb6b76fa-Abstract.html" rel="noopener noreferrer"&gt;Konda &amp;amp; Tsitsiklis, 2000&lt;/a&gt;) — Learn a value function baseline alongside the policy&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;A3C&lt;/strong&gt; (&lt;a href="https://arxiv.org/abs/1602.01783" rel="noopener noreferrer"&gt;Mnih et al., 2016&lt;/a&gt;) — Asynchronous actor-critic with parallel workers&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;PPO&lt;/strong&gt; (&lt;a href="https://arxiv.org/abs/1707.06347" rel="noopener noreferrer"&gt;Schulman et al., 2017&lt;/a&gt;) — Clipped surrogate objective for stable updates; the default algorithm behind ChatGPT's RLHF&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;DDPG&lt;/strong&gt; (&lt;a href="https://arxiv.org/abs/1509.02971" rel="noopener noreferrer"&gt;Lillicrap et al., 2016&lt;/a&gt;) — Continuous-action policy gradients with a deterministic policy&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;SAC&lt;/strong&gt; (&lt;a href="https://arxiv.org/abs/1801.01290" rel="noopener noreferrer"&gt;Haarnoja et al., 2018&lt;/a&gt;) — Entropy-regularised policy gradients for robust exploration&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Historical Context
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Williams (1992)&lt;/strong&gt; — REINFORCE: the first general-purpose policy gradient algorithm&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Sutton et al. (1999)&lt;/strong&gt; — Policy gradient theorem: proves the gradient formula for function approximators&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Kakade (2001)&lt;/strong&gt; — Natural policy gradients: use Fisher information for better updates&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Schulman et al. (2015)&lt;/strong&gt; — TRPO: trust regions for policy optimisation&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Schulman et al. (2017)&lt;/strong&gt; — PPO: simplified TRPO with clipped objective; now the industry standard&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;OpenAI (2017)&lt;/strong&gt; — RLHF: policy gradients for aligning language models&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://link.springer.com/article/10.1007/BF00992696" rel="noopener noreferrer"&gt;Williams (1992)&lt;/a&gt; — &lt;em&gt;Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning&lt;/em&gt; — the REINFORCE paper&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://papers.nips.cc/paper/1999/hash/464d828b85b0bed98e80ade0a5c43b0f-Abstract.html" rel="noopener noreferrer"&gt;Sutton et al. (1999)&lt;/a&gt; — &lt;em&gt;Policy Gradient Methods for Reinforcement Learning with Function Approximation&lt;/em&gt; — the policy gradient theorem&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://karpathy.github.io/2016/05/31/rl/" rel="noopener noreferrer"&gt;Karpathy (2016)&lt;/a&gt; — &lt;em&gt;Deep Reinforcement Learning: Pong from Pixels&lt;/em&gt; — the blog post that inspired the original code&lt;/li&gt;
&lt;li&gt;
&lt;a href="http://incompleteideas.net/book/the-book-2nd.html" rel="noopener noreferrer"&gt;Sutton &amp;amp; Barto (2018)&lt;/a&gt; — Chapter 13 (Policy Gradient Methods) — freely available online&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/q-learning-visualizer" rel="noopener noreferrer"&gt;Q-Learning Visualiser&lt;/a&gt; — Compare value-based RL with the policy gradient approach covered in this post&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;Deep Q-Networks: Experience Replay and Target Networks&lt;/a&gt; — Value-based RL with neural networks, the approach we move beyond here&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;Q-Learning from Scratch&lt;/a&gt; — Tabular value-based RL, the foundation for understanding the value→policy shift&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/backpropagation-neural-nets-from-first-principles" rel="noopener noreferrer"&gt;Backpropagation Demystified&lt;/a&gt; — We implement backprop manually here too; this post covers the fundamentals&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/genetic-algorithms-from-scratch" rel="noopener noreferrer"&gt;Genetic Algorithms&lt;/a&gt; — Gradient-free optimisation, an alternative to policy gradients for non-differentiable objectives&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/policy_gradient_cartpole.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Reintroduce reward shaping&lt;/strong&gt; — Replace the full-horizon &lt;code&gt;discount_rewards&lt;/code&gt; with the original blame-window shaping (penalise only the last N actions). How does training speed change?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Vary the blame window&lt;/strong&gt; — Try &lt;code&gt;$N \in \{5, 10, 20, 50\}$&lt;/code&gt;. How does the learning curve respond?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Remove standardisation&lt;/strong&gt; — Comment out the mean/std normalisation of rewards. Does the agent still learn?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Gamma sweep&lt;/strong&gt; — Try &lt;code&gt;$\gamma \in \{0.8, 0.95, 0.99, 0.999\}$&lt;/code&gt;. How does the discount factor affect learning?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Compare with DQN&lt;/strong&gt; — Plot the REINFORCE learning curve alongside &lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;DQN's&lt;/a&gt; on the same axes. Which learns faster? Which is more stable?&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;REINFORCE showed that you can optimise a policy directly — no Q-values needed. The gradient signal is noisy, but with variance reduction (baselines, reward standardisation, batch updates) it works remarkably well. The same core idea — gradient ascent on expected reward — underpins every modern policy gradient algorithm, from PPO to the RLHF that aligns large language models. The fundamentals haven't changed since Williams (1992); only the variance reduction has improved.&lt;/p&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is the difference between policy gradient methods and value-based methods like DQN?
&lt;/h3&gt;

&lt;p&gt;Value-based methods learn a value function (such as Q-values) and derive a policy indirectly by choosing the action with the highest value. Policy gradient methods parameterise the policy directly and optimise it via gradient ascent on expected reward. The key advantage of policy gradients is that they naturally handle continuous action spaces and produce stochastic policies with built-in exploration.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why is reward standardisation important in REINFORCE?
&lt;/h3&gt;

&lt;p&gt;Without standardisation, all rewards in CartPole are positive (+1 per timestep), so the gradient reinforces every action, just some more than others. Subtracting the mean makes roughly half the actions receive positive reinforcement (above average) and half negative (below average). This acts as a simple baseline that dramatically reduces gradient variance and stabilises training.&lt;/p&gt;

&lt;h3&gt;
  
  
  What does the discount factor gamma control?
&lt;/h3&gt;

&lt;p&gt;Gamma determines how much weight future rewards receive relative to immediate rewards. A value of 0.99 means an action taken 100 steps before the end still receives about 37% of the terminal reward. Higher gamma values encourage long-term planning but increase variance, while lower values make the agent more short-sighted but produce more stable gradients.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why was RMSProp chosen as the optimiser instead of plain gradient ascent?
&lt;/h3&gt;

&lt;p&gt;RMSProp maintains a running average of squared gradients and adapts the learning rate for each parameter individually. Parameters with consistently large gradients get smaller effective learning rates, preventing them from dominating the update. This adaptive behaviour is especially important in reinforcement learning, where gradient magnitudes can vary wildly across episodes and parameters.&lt;/p&gt;

&lt;h3&gt;
  
  
  When should I use REINFORCE versus more advanced algorithms like PPO?
&lt;/h3&gt;

&lt;p&gt;Vanilla REINFORCE is best suited for simple environments, educational purposes, and situations where you want full control over the implementation. For complex environments, long episodes, or production systems, PPO and other actor-critic methods are preferred because they use a learned value function baseline that reduces variance and clip the policy update (the probability ratio in PPO's surrogate objective) to prevent catastrophic policy changes.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why does REINFORCE accumulate gradients over multiple episodes before updating?
&lt;/h3&gt;

&lt;p&gt;Accumulating gradients over a batch of episodes (5 in this implementation) reduces the variance of the gradient estimate. A single episode can be noisy due to the stochastic nature of both the policy and the environment. Averaging over multiple episodes produces a smoother, more reliable gradient signal, similar to minibatch gradient descent in supervised learning.&lt;/p&gt;

</description>
      <category>reinforcementlearning</category>
      <category>deeplearning</category>
      <category>optimisation</category>
    </item>
  </channel>
</rss>
