<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Berkan Sesen</title>
    <description>The latest articles on DEV Community by Berkan Sesen (@berkan_sesen).</description>
    <link>https://dev.to/berkan_sesen</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3843317%2F81217b17-b750-4b21-a7f8-d6dbafdcf816.jpg</url>
      <title>DEV Community: Berkan Sesen</title>
      <link>https://dev.to/berkan_sesen</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/berkan_sesen"/>
    <language>en</language>
    <item>
      <title>Cointegration and Pairs Trading: When Time Series Move Together</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Sun, 24 May 2026 10:32:53 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/cointegration-and-pairs-trading-when-time-series-move-together-2cg6</link>
      <guid>https://dev.to/berkan_sesen/cointegration-and-pairs-trading-when-time-series-move-together-2cg6</guid>
      <description>&lt;p&gt;Pairs trading rests on a simple idea: find two assets that move together, wait for them to diverge, and bet on convergence. The hard part is defining "move together." Two commodity ETFs, EWA (Australia) and EWC (Canada), show a 0.95 correlation over a multi-year window. A mean-reversion trader sees that number and assumes the spread will keep snapping back. But then the spread drifts apart and stays apart for months. The correlation was real; the strategy still bled money. The problem is that correlation tells you whether two series tend to move in the same direction. Cointegration tells you whether they are bound together by a long-run equilibrium, so that any deviation is temporary and will correct itself.&lt;/p&gt;

&lt;p&gt;The distinction matters because most financial time series are non-stationary (they wander without a fixed mean). Two non-stationary series can be highly correlated by pure coincidence (the "spurious regression" problem identified by &lt;a href="https://doi.org/10.1016/0304-4076(74)90034-7" rel="noopener noreferrer"&gt;Granger and Newbold, 1974&lt;/a&gt;). Cointegration is the formal property that their &lt;em&gt;difference&lt;/em&gt; (or some other linear combination) is stationary, meaning it genuinely reverts to a mean; a cointegration test checks whether that property holds.&lt;/p&gt;

&lt;p&gt;By the end of this post, you'll test for cointegration using both the Engle-Granger and Johansen methods, understand when and why they disagree, and build a simple pairs trading strategy on real ETF data.&lt;/p&gt;

&lt;h2&gt;
  
  
  The Data: Country ETF Pairs
&lt;/h2&gt;

&lt;p&gt;We use two iShares country ETFs: EWA (Australia) and EWC (Canada). Both countries are commodity exporters with similar economic drivers (mining, energy, agriculture), so there's a fundamental reason to expect a long-run relationship. This is the same pair used in the original R analysis we're translating.&lt;/p&gt;

&lt;p&gt;For comparison, we also test GLD (gold) and GDX (gold miners). Despite the obvious connection, gold miners have idiosyncratic risks (management, costs, leverage) that can break cointegration.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fc172hvtzk3s5qpcjg74y.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fc172hvtzk3s5qpcjg74y.webp" alt="Dual-axis time series of EWA and EWC ETF prices from 2007 to 2023, showing similar patterns with occasional divergences" width="800" height="416"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The two ETFs clearly track each other over 17 years. They crash together in 2008, recover together, and diverge temporarily during COVID before reconverging. But visual similarity isn't proof of cointegration. We need a formal test.&lt;/p&gt;

&lt;h2&gt;
  
  
  Quick Win: Testing for Cointegration
&lt;/h2&gt;

&lt;p&gt;Click the badge to run this yourself:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/time-series/cointegration_pairs_trading.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pandas&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;yfinance&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;yf&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;statsmodels.tsa.stattools&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;adfuller&lt;/span&gt;

&lt;span class="c1"&gt;# Download EWA and EWC adjusted close prices
&lt;/span&gt;&lt;span class="n"&gt;ewa&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;yf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;download&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;EWA&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;start&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;2007-01-01&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;end&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;2023-12-31&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                   &lt;span class="n"&gt;auto_adjust&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;progress&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Close&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;ewc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;yf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;download&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;EWC&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;start&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;2007-01-01&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;end&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;2023-12-31&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                   &lt;span class="n"&gt;auto_adjust&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;progress&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Close&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="c1"&gt;# Align on common trading days
&lt;/span&gt;&lt;span class="n"&gt;common&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ewa&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;intersection&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ewc&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ewa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ewc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ewa&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;common&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;ewc&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;common&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ewa&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; trading days, &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;ewa&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;date&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; to &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;ewa&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;date&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;4278 trading days, 2007-01-03 to 2023-12-29
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The Engle-Granger test is two steps: regress one series on the other, then test whether the residuals are stationary.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;statsmodels.regression.linear_model&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;OLS&lt;/span&gt;

&lt;span class="c1"&gt;# Regress EWC on EWA (no intercept, following the original R code)
&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;OLS&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ewc&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ewa&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;spread&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;resid&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Hedge ratio: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# ADF test on the residuals
&lt;/span&gt;&lt;span class="n"&gt;adf_stat&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;adf_pval&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;crit_vals&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;adfuller&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;spread&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;regression&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;ADF statistic: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;adf_stat&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;p-value: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;adf_pval&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Hedge ratio: 1.5674
ADF statistic: -3.1704
p-value: 0.0015
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



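&lt;p&gt;The ADF test rejects the unit root null hypothesis at the 1% level (p = 0.0015), which suggests that the spread between EWC and 1.57 times EWA is stationary and that the two ETFs are cointegrated: deviations from the long-run relationship tend to correct themselves. One caveat: &lt;code&gt;adfuller&lt;/code&gt; computes its p-value as if the series had been observed directly, whereas here the hedge ratio was estimated first. Strictly, the Engle-Granger test calls for more negative critical values in that case, so this p-value is somewhat optimistic.&lt;/p&gt;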

&lt;p&gt;The spread looks like this:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fhgn9t4a5vbc1h4i2p2sm.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fhgn9t4a5vbc1h4i2p2sm.webp" alt="Cointegration spread oscillating around zero with 2-sigma bands, showing mean-reverting behaviour" width="800" height="409"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The spread wanders but always returns to the mean. It doesn't drift permanently in one direction like a random walk would. This mean-reverting property is exactly what makes cointegration useful for trading.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Stationarity: The Key Idea
&lt;/h3&gt;

&lt;p&gt;A stationary time series has a constant mean and variance over time. If you pick any window, the statistics look roughly the same. Stock prices are almost never stationary (they trend up or down), but the &lt;em&gt;spread&lt;/em&gt; between two cointegrated stocks can be.&lt;/p&gt;

&lt;p&gt;The Augmented Dickey-Fuller (ADF) test checks whether a series has a unit root (non-stationary). The null hypothesis is "this series has a unit root" (bad for us). A small p-value means we can reject the null and conclude the series is stationary (good for us).&lt;/p&gt;

&lt;h3&gt;
  
  
  The Engle-Granger Two-Step Method
&lt;/h3&gt;

&lt;p&gt;&lt;a href="https://doi.org/10.2307/1913236" rel="noopener noreferrer"&gt;Engle and Granger (1987)&lt;/a&gt; proposed a beautifully simple procedure:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Regress&lt;/strong&gt; one time series on the other: &lt;code&gt;$\text{EWC}_t = \beta \cdot \text{EWA}_t + \varepsilon_t$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Test&lt;/strong&gt; the residuals &lt;code&gt;$\varepsilon_t$&lt;/code&gt; for stationarity using the ADF test&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;If the residuals are stationary, the two series are cointegrated with cointegrating vector &lt;code&gt;$[1, -\beta]$&lt;/code&gt;. The coefficient &lt;code&gt;$\beta = 1.57$&lt;/code&gt; is the &lt;strong&gt;hedge ratio&lt;/strong&gt;. Because the regression is run on price levels, it is a ratio of shares: for every share of EWC you hold, you short 1.57 shares of EWA to neutralise the common trend.&lt;/p&gt;

&lt;p&gt;One subtlety: which series is the dependent variable matters. The original R code runs both directions (EWC on EWA and EWA on EWC) and picks the regression with the most negative ADF statistic. In our case, both directions give similar results.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why Not Just Use Correlation?
&lt;/h3&gt;

&lt;p&gt;Two series can have a correlation of 0.99 and still not be cointegrated. Imagine two random walks that happen to trend upward over the same period. Their correlation will be high, but their spread will drift without bound. Conversely, two cointegrated series can have low short-term correlation if they temporarily diverge before snapping back. &lt;a href="https://sesen.ai/blog/linear-regression-five-ways" rel="noopener noreferrer"&gt;Correlation measures co-movement; cointegration measures co-wandering with a leash.&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  The Johansen Test: A Multivariate Approach
&lt;/h3&gt;

&lt;p&gt;The Engle-Granger method estimates a single cointegrating relationship and requires choosing a dependent variable. The Johansen test, introduced by &lt;a href="https://doi.org/10.2307/2938278" rel="noopener noreferrer"&gt;Johansen (1991)&lt;/a&gt;, handles any number of time series simultaneously and treats them symmetrically. It works through a vector autoregression (VAR) framework and estimates the &lt;strong&gt;cointegration rank&lt;/strong&gt;: how many independent cointegrating relationships exist among the series.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;statsmodels.tsa.vector_ar.vecm&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;coint_johansen&lt;/span&gt;

&lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;column_stack&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;ewa&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ewc&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;result&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;coint_johansen&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;det_order&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;k_ar_diff&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Trace statistic (r=0): &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lr1&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;95% critical value:    &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cvt&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Trace statistic (r=0): 16.66
95% critical value:    15.49
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The trace statistic (16.66) exceeds the 95% critical value (15.49), so Johansen also rejects the null of no cointegration. Both methods agree for EWA/EWC.&lt;/p&gt;

&lt;h3&gt;
  
  
  When Tests Disagree
&lt;/h3&gt;

&lt;p&gt;The original R code used a shorter date range where Engle-Granger found marginal cointegration (p = 7%) but Johansen did not. This highlights an important practical point: cointegration tests are sensitive to sample period, structural breaks, and lag selection. The 2008 financial crisis, for instance, can distort the relationship. When the two tests disagree, it's usually a sign that cointegration is weak or regime-dependent, not a reason to pick the more favourable result.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Pairs Trading: Exploiting Mean Reversion
&lt;/h3&gt;

&lt;p&gt;If the spread is stationary, we can trade its mean-reversion. The strategy is simple:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Compute a rolling z-score of the spread: &lt;code&gt;$z_t = \frac{s_t - \bar{s}_{60}}{\sigma_{60}}$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Buy&lt;/strong&gt; the spread when &lt;code&gt;$z &amp;lt; -2$&lt;/code&gt; (spread is unusually cheap)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Sell&lt;/strong&gt; the spread when &lt;code&gt;$z &amp;gt; +2$&lt;/code&gt; (spread is unusually expensive)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Close&lt;/strong&gt; when &lt;code&gt;$z$&lt;/code&gt; crosses zero (spread has reverted)&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;"Buying the spread" means going long EWC and short EWA (proportional to the hedge ratio). "Selling the spread" means the opposite.&lt;/p&gt;
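
&lt;p&gt;The rules above can be sketched as a small state machine. This is one possible reading of the strategy rather than the notebook's implementation: the function name &lt;code&gt;zscore_positions&lt;/code&gt; is hypothetical, the spread is a simulated AR(1) series rather than the EWA/EWC spread, and the PnL convention (yesterday's position times today's change in the spread, with no costs) is an assumption.&lt;/p&gt;

```python
import numpy as np
import pandas as pd

def zscore_positions(spread, window=60, entry=2.0):
    """Return the rolling z-score and positions (+1 long, -1 short, 0 flat)."""
    s = pd.Series(spread, dtype=float)
    z = (s - s.rolling(window).mean()) / s.rolling(window).std()

    position = np.zeros(len(s))
    current = 0.0
    for i, zi in enumerate(z.to_numpy()):
        if np.isnan(zi):
            current = 0.0              # no signal until the window fills
        elif current == 0.0:
            if -zi > entry:            # z below -entry: spread is cheap, buy it
                current = 1.0
            elif zi > entry:           # z above +entry: spread is rich, sell it
                current = -1.0
        elif current * zi >= 0.0:      # z has crossed zero: close the trade
            current = 0.0
        position[i] = current
    return z, pd.Series(position, index=s.index)

# Synthetic mean-reverting spread (AR(1)), used only to exercise the function
rng = np.random.default_rng(7)
spread = np.zeros(1500)
for t in range(1, len(spread)):
    spread[t] = 0.95 * spread[t - 1] + rng.normal(0, 1)

z, position = zscore_positions(spread)

# PnL per unit of spread: yesterday's position times today's spread change
pnl = position.shift(1).fillna(0.0) * pd.Series(spread).diff().fillna(0.0)
print(f"Days in the market: {int((position != 0).sum())}")
print(f"Cumulative PnL per unit of spread: {pnl.sum():.2f}")
```

&lt;p&gt;Lagging the position by one day before multiplying avoids look-ahead: a signal computed from today's close can only earn tomorrow's move.&lt;/p&gt;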

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F5ir1w36346dd6zwrlxgg.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F5ir1w36346dd6zwrlxgg.webp" alt="Rolling z-score of the spread with buy and sell thresholds at negative and positive 2" width="800" height="445"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The z-score oscillates between roughly -4 and +4, regularly crossing the trading thresholds. Each crossing is a potential trade entry or exit.&lt;/p&gt;

&lt;h3&gt;
  
  
  Backtest Results
&lt;/h3&gt;

&lt;p&gt;Running this simple strategy over 17 years of EWA/EWC data:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fpefglknns3ytabykb0x6.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fpefglknns3ytabykb0x6.webp" alt="Cumulative PnL curve and position indicator for the pairs trading backtest" width="800" height="637"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The strategy generates a cumulative PnL of about $19 per unit of spread, with 135 trades and an annualised Sharpe ratio of 0.69. The equity curve is mostly upward-sloping, with a significant drawdown during 2012-2014 when the spread drifted for an extended period.&lt;/p&gt;

&lt;p&gt;This is a toy backtest (no transaction costs, slippage, or financing costs). Real implementation requires careful attention to execution, but the core signal (mean-reverting spread) is genuine.&lt;/p&gt;

&lt;h3&gt;
  
  
  A Pair That Fails: GLD vs GDX
&lt;/h3&gt;

&lt;p&gt;To see what non-cointegration looks like, consider GLD (gold) and GDX (gold miners). Despite the intuitive connection, gold miners have company-specific risks that break the long-run equilibrium.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F7b9kxifdaajlmm2gclu5.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F7b9kxifdaajlmm2gclu5.webp" alt="ADF test comparison showing EWA/EWC clearly rejecting the unit root while GLD/GDX does not" width="800" height="346"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The ADF test statistic for EWA/EWC (-3.17) is well past all critical values. For GLD/GDX (-1.64), it fails even the 10% level. The Johansen test confirms: GLD/GDX shows no evidence of cointegration (trace stat 13.38 &amp;lt; 15.49 critical value).&lt;/p&gt;

&lt;p&gt;This is why fundamental reasoning alone isn't enough. You need the statistical test.&lt;/p&gt;

&lt;h3&gt;
  
  
  Autocorrelation: Visual Evidence of Stationarity
&lt;/h3&gt;

&lt;p&gt;The autocorrelation function (ACF) of the spread provides visual confirmation:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjkhpdup047diod0k723r.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjkhpdup047diod0k723r.webp" alt="Autocorrelation plot of the spread showing high persistence but gradual decay" width="800" height="442"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The ACF decays slowly from 1.0, which is typical for a stationary but highly persistent process. A truly non-stationary series would show autocorrelations that barely decay at all. The gradual decline suggests the spread reverts, but slowly (a mean-reversion half-life of roughly 3 months, based on the decay rate).&lt;/p&gt;
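
&lt;p&gt;A half-life can also be estimated numerically rather than read off the plot. One common approach, sketched here under the assumption that the spread behaves like an AR(1) process, fits the one-step persistence and converts it to the number of steps needed for a shock to halve. The helper name &lt;code&gt;half_life&lt;/code&gt; and the simulated series are illustrative and not taken from the original notebook.&lt;/p&gt;

```python
import numpy as np

def half_life(spread):
    """Estimate the mean-reversion half-life (in observations) from an AR(1) fit."""
    s = np.asarray(spread, dtype=float)
    lagged, delta = s[:-1], np.diff(s)
    # Fit delta_t = slope * s_{t-1} + intercept; slope = phi - 1 for an AR(1)
    slope, intercept = np.polyfit(lagged, delta, 1)
    phi = 1.0 + slope
    # A shock decays as phi**k, so it halves after k = ln(0.5) / ln(phi) steps.
    # Only meaningful when phi lies strictly between 0 and 1.
    return np.log(0.5) / np.log(phi)

# Synthetic AR(1) spread with phi = 0.95 (true half-life is about 13.5 steps)
rng = np.random.default_rng(3)
s = np.zeros(5000)
for t in range(1, len(s)):
    s[t] = 0.95 * s[t - 1] + rng.normal(0, 1)

print(f"Estimated half-life: {half_life(s):.1f} observations")
```

&lt;p&gt;On daily data the result is in trading days, so a half-life of about three months would correspond to roughly 60 observations.&lt;/p&gt;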

&lt;h3&gt;
  
  
  Hyperparameter Choices
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Parameter&lt;/th&gt;
&lt;th&gt;Value&lt;/th&gt;
&lt;th&gt;Why&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;ETF pair&lt;/td&gt;
&lt;td&gt;EWA/EWC&lt;/td&gt;
&lt;td&gt;Original R code pair; commodity exporters with fundamental economic link&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Date range&lt;/td&gt;
&lt;td&gt;2007-2023&lt;/td&gt;
&lt;td&gt;17 years covering multiple market regimes (GFC, COVID)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;ADF regression&lt;/td&gt;
&lt;td&gt;No intercept&lt;/td&gt;
&lt;td&gt;Matches original R code (&lt;code&gt;type="nc"&lt;/code&gt;); spread should be zero-mean&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Johansen settings&lt;/td&gt;
&lt;td&gt;&lt;code&gt;det_order=0, k_ar_diff=1&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Matches R &lt;code&gt;ecdet="none", K=2&lt;/code&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Z-score window&lt;/td&gt;
&lt;td&gt;60 days&lt;/td&gt;
&lt;td&gt;~3 months; balances responsiveness with stability&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Entry threshold&lt;/td&gt;
&lt;td&gt;±2σ&lt;/td&gt;
&lt;td&gt;Standard for pairs trading; ~5% of observations in tails&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Exit threshold&lt;/td&gt;
&lt;td&gt;0&lt;/td&gt;
&lt;td&gt;Close when spread returns to rolling mean&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Engle and Granger (1987): The Nobel Prize Paper
&lt;/h3&gt;

&lt;p&gt;Robert Engle and Clive Granger introduced cointegration in their &lt;a href="https://doi.org/10.2307/1913236" rel="noopener noreferrer"&gt;1987 paper&lt;/a&gt; "Co-Integration and Error Correction: Representation, Estimation, and Testing", published in Econometrica. The work earned Granger the Nobel Prize in Economics in 2003 (shared with Engle, who was recognised for ARCH models).&lt;/p&gt;

&lt;p&gt;Their key insight was that while individual economic time series may be non-stationary (integrated of order 1, or I(1)), linear combinations of them can be stationary (I(0)). This formalised the intuition that certain economic variables are "tied together" by equilibrium forces, even though each variable wanders on its own.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"A test for cointegration can be thought of as a pre-test to avoid 'spurious regression' situations."&lt;br&gt;
-- Engle &amp;amp; Granger (1987)&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;The two-step procedure we implemented (regress, then test residuals) is their original method. It's simple, intuitive, and remains the most widely used cointegration test for pairs.&lt;/p&gt;

&lt;h3&gt;
  
  
  Johansen (1991): The Multivariate Extension
&lt;/h3&gt;

&lt;p&gt;Søren Johansen's &lt;a href="https://doi.org/10.2307/2938278" rel="noopener noreferrer"&gt;1991 paper&lt;/a&gt; "Estimation and Hypothesis Testing of Cointegration Vectors in Gaussian Vector Autoregressive Models" extended cointegration testing to any number of variables. Instead of running pairwise regressions, Johansen's trace test estimates the rank of the cointegration matrix directly using eigenvalue decomposition.&lt;/p&gt;

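&lt;p&gt;For two variables, the Johansen test and Engle-Granger usually agree. For three or more (e.g., a basket of commodity ETFs), Johansen is the standard choice, because it can detect more than one cointegrating relationship and does not require picking a dependent variable.&lt;/p&gt;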

&lt;h3&gt;
  
  
  The Dickey-Fuller Foundation
&lt;/h3&gt;

&lt;p&gt;The residual test in Engle-Granger is an &lt;a href="https://doi.org/10.2307/2286348" rel="noopener noreferrer"&gt;Augmented Dickey-Fuller test&lt;/a&gt; (building on Dickey &amp;amp; Fuller, 1979) for a unit root, and Johansen's procedure generalises the same unit-root question to a vector autoregression. The ADF test fits the model &lt;code&gt;$\Delta y_t = \alpha y_{t-1} + \sum \gamma_i \Delta y_{t-i} + \varepsilon_t$&lt;/code&gt; and tests whether &lt;code&gt;$\alpha = 0$&lt;/code&gt; (unit root) vs &lt;code&gt;$\alpha &amp;lt; 0$&lt;/code&gt; (stationary). Under the null hypothesis the test statistic doesn't follow a standard t-distribution, so special critical values (tabulated by Dickey and Fuller) are needed. Residual-based cointegration tests need their own, more negative, critical values because the hedge ratio is itself estimated.&lt;/p&gt;
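&lt;p&gt;The regression behind the test statistic can be sketched in a few lines of NumPy. This is an illustrative sketch of the equation above (no constant term, one lagged difference by default), not a replacement for a library implementation; the helper name &lt;code&gt;adf_t_stat&lt;/code&gt; and the simulated series are assumptions made for the example, and it deliberately returns only the statistic, not a p-value.&lt;/p&gt;

```python
import numpy as np

def adf_t_stat(y, lags=1):
    """t-statistic of alpha in dy_t = alpha*y_{t-1} + sum_i gamma_i*dy_{t-i} + e_t."""
    y = np.asarray(y, dtype=float)
    dy = np.diff(y)
    target = dy[lags:]                          # dy_t
    cols = [y[lags:-1]]                         # y_{t-1}
    for i in range(1, lags + 1):
        cols.append(dy[lags - i:len(dy) - i])   # dy_{t-i}
    X = np.column_stack(cols)
    coef, *_ = np.linalg.lstsq(X, target, rcond=None)
    resid = target - X @ coef
    sigma2 = resid @ resid / (len(target) - X.shape[1])
    se_alpha = np.sqrt(sigma2 * np.linalg.inv(X.T @ X)[0, 0])
    return float(coef[0] / se_alpha)

rng = np.random.default_rng(1)
shocks = rng.normal(size=1000)

random_walk = np.cumsum(shocks)                 # unit root: alpha = 0

ar1 = np.zeros(1000)                            # stationary: alpha = 0.5 - 1 = -0.5
for t in range(1, 1000):
    ar1[t] = 0.5 * ar1[t - 1] + shocks[t]

t_rw = adf_t_stat(random_walk)
t_ar = adf_t_stat(ar1)
print("random walk t-stat:", round(t_rw, 2))
print("AR(1) t-stat:      ", round(t_ar, 2))
```

&lt;p&gt;The stationary series should produce a strongly negative statistic, while the random walk's statistic stays much closer to zero. Either number only becomes a test once it is compared against Dickey-Fuller critical values rather than the usual t-table.&lt;/p&gt;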

&lt;h3&gt;
  
  
  Pairs Trading in Practice
&lt;/h3&gt;

&lt;p&gt;The academic foundation for pairs trading was established by &lt;a href="https://doi.org/10.1093/rfs/hhj020" rel="noopener noreferrer"&gt;Gatev, Goetzmann, and Rouwenhorst (2006)&lt;/a&gt; in "Pairs Trading: Performance of a Relative-Value Arbitrage Rule". They analysed pairs trading on US equities using daily data from 1962 to 2002 and found average annualised excess returns of up to about 11% for self-financing portfolios of the best pairs.&lt;/p&gt;

&lt;p&gt;For a comprehensive treatment, &lt;a href="https://www.wiley.com/en-gb/Pairs+Trading%3A+Quantitative+Methods+and+Analysis-p-9780471460671" rel="noopener noreferrer"&gt;Vidyamurthy (2004)&lt;/a&gt; &lt;em&gt;Pairs Trading: Quantitative Methods and Analysis&lt;/em&gt; covers the full pipeline from pair selection to execution.&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;The Nobel Prize paper:&lt;/strong&gt; &lt;a href="https://doi.org/10.2307/1913236" rel="noopener noreferrer"&gt;Engle &amp;amp; Granger (1987)&lt;/a&gt;, "Co-Integration and Error Correction", Econometrica&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Multivariate extension:&lt;/strong&gt; &lt;a href="https://doi.org/10.2307/2938278" rel="noopener noreferrer"&gt;Johansen (1991)&lt;/a&gt;, "Estimation and Hypothesis Testing of Cointegration Vectors"&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Unit root foundations:&lt;/strong&gt; &lt;a href="https://doi.org/10.2307/2286348" rel="noopener noreferrer"&gt;Dickey &amp;amp; Fuller (1979)&lt;/a&gt;, "Distribution of the Estimators for Autoregressive Time Series"&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Pairs trading evidence:&lt;/strong&gt; &lt;a href="https://doi.org/10.1093/rfs/hhj020" rel="noopener noreferrer"&gt;Gatev et al. (2006)&lt;/a&gt;, "Pairs Trading: Performance of a Relative-Value Arbitrage Rule"&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Practical guide:&lt;/strong&gt; Vidyamurthy (2004), &lt;em&gt;Pairs Trading: Quantitative Methods and Analysis&lt;/em&gt;, Wiley&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Interactive Tools
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/black-scholes-calculator" rel="noopener noreferrer"&gt;Black-Scholes Calculator&lt;/a&gt; — Price options on the assets in your pairs trades&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/kelly-criterion-calculator" rel="noopener noreferrer"&gt;Kelly Criterion Calculator&lt;/a&gt; — Determine optimal position sizing for your trading strategy&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/drawdown-calculator" rel="noopener noreferrer"&gt;Drawdown Calculator&lt;/a&gt; — Analyse portfolio drawdowns and risk metrics&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Related Posts
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/hidden-markov-models-when-clusters-have-memory" rel="noopener noreferrer"&gt;Hidden Markov Models: When Clusters Have Memory&lt;/a&gt; (regime detection in time series)&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/mcmc-poisson-mixture-earthquake-regimes" rel="noopener noreferrer"&gt;MCMC for Mixture Models: Inferring Earthquake Regimes&lt;/a&gt; (detecting hidden regimes in count data)&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/linear-regression-five-ways" rel="noopener noreferrer"&gt;Linear Regression Five Ways&lt;/a&gt; (the regression foundation that Engle-Granger builds on)&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;Maximum Likelihood Estimation from Scratch&lt;/a&gt; (the estimation framework underlying the Johansen test)&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is the difference between correlation and cointegration?
&lt;/h3&gt;

&lt;p&gt;Correlation measures whether two series tend to move in the same direction over short periods. Cointegration tests whether a linear combination of two series is stationary, meaning deviations from their long-run relationship are temporary and self-correcting. Two highly correlated series can drift apart permanently, while two cointegrated series are bound by an equilibrium that pulls them back together.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can cointegration break down over time?
&lt;/h3&gt;

&lt;p&gt;Yes. Cointegration is not permanent. Structural changes in the economy, shifts in industry dynamics, or regulatory events can destroy a previously stable relationship. This is why practitioners re-test cointegration periodically using rolling windows and monitor spread behaviour for signs of regime change.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why does the Engle-Granger test sometimes disagree with the Johansen test?
&lt;/h3&gt;

&lt;p&gt;The two tests use different methodologies. Engle-Granger runs a single regression and tests the residuals, while Johansen uses a vector autoregression framework. They can disagree when cointegration is weak, when the sample period includes structural breaks, or when lag selection differs. Disagreement is usually a warning sign that the relationship is fragile rather than robust.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is a hedge ratio and why does it matter for pairs trading?
&lt;/h3&gt;

&lt;p&gt;The hedge ratio is the coefficient from the cointegrating regression. It tells you how many units of one asset to hold against the other so that the combined position has a stationary spread. Getting the hedge ratio wrong means your spread will drift rather than revert, defeating the purpose of the strategy.&lt;/p&gt;
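&lt;p&gt;As a rough illustration of why the ratio matters, the sketch below estimates it by ordinary least squares on simulated data and compares the resulting spread with one built from a deliberately wrong ratio. The simulated series, the true ratio of 2.0, and the wrong ratio of 1.5 are assumptions chosen for the example.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
trend = np.cumsum(rng.normal(size=n))
x = trend + rng.normal(scale=0.5, size=n)
y = 2.0 * trend + rng.normal(scale=0.5, size=n)   # true hedge ratio is 2.0

# Step 1 of Engle-Granger: regress y on x; the slope is the hedge ratio.
hedge_ratio, intercept = np.polyfit(x, y, deg=1)

# The spread is the regression residual: long 1 unit of y, short hedge_ratio units of x.
spread = y - (intercept + hedge_ratio * x)

# A deliberately wrong ratio leaves part of the random-walk trend in the spread.
bad_spread = y - 1.5 * x

print("estimated hedge ratio:", round(float(hedge_ratio), 3))
print("spread variance (estimated ratio):", round(float(np.var(spread)), 2))
print("spread variance (ratio = 1.5):    ", round(float(np.var(bad_spread)), 2))
```

&lt;p&gt;With the wrong ratio, half a unit of the random walk remains in the spread, so its variance keeps growing with the sample instead of settling down.&lt;/p&gt;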

&lt;h3&gt;
  
  
  Is pairs trading still profitable in modern markets?
&lt;/h3&gt;

&lt;p&gt;Academic evidence suggests that pairs trading returns have declined since the strategy became widely known in the early 2000s. However, it can still be profitable when applied to less liquid markets, when combined with fundamental analysis to select pairs, or when enhanced with more sophisticated signal generation. Transaction costs and execution quality are critical factors.&lt;/p&gt;

&lt;h3&gt;
  
  
  Do I need to difference the price series before testing for cointegration?
&lt;/h3&gt;

&lt;p&gt;No. Cointegration testing specifically requires the original (undifferenced) price series. The whole point is to find a linear combination of non-stationary I(1) series that produces a stationary I(0) result. If you difference first, you remove the very relationship you are trying to detect.&lt;/p&gt;

</description>
      <category>timeseries</category>
      <category>quantfinance</category>
      <category>statistics</category>
    </item>
    <item>
      <title>Cox Proportional Hazards: The Workhorse of Survival Analysis</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Mon, 18 May 2026 07:50:53 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/cox-proportional-hazards-the-workhorse-of-survival-analysis-8an</link>
      <guid>https://dev.to/berkan_sesen/cox-proportional-hazards-the-workhorse-of-survival-analysis-8an</guid>
      <description>&lt;p&gt;Survival analysis starts with a question: how long until an event happens? A patient relapses, a customer churns, a borrower defaults on a loan, a prisoner is rearrested. Parametric models answer by assuming a shape for the hazard — Weibull, log-logistic, exponential — and estimating its parameters. The Cox model sidesteps the entire question. You get hazard ratios, survival curves, and covariate effects without ever specifying what the baseline hazard looks like.&lt;/p&gt;

&lt;p&gt;In our &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;Bayesian survival analysis post&lt;/a&gt;, we used PyMC to fit an accelerated failure time model with explicit distributional assumptions. The Cox model takes the opposite approach: it's semi-parametric, making no assumption about the baseline hazard. This is why it dominates applied survival analysis in medicine, criminal justice, and customer churn modelling.&lt;/p&gt;

&lt;p&gt;By the end of this post, you'll fit a Cox model to real recidivism data, interpret hazard ratios, test the proportional hazards assumption, and extend the model with time-dependent covariates.&lt;/p&gt;

&lt;h2&gt;
  
  
  The Data: Recidivism After Prison
&lt;/h2&gt;

&lt;p&gt;The Rossi recidivism dataset follows 432 male prisoners for one year after their release from prison. The primary question: does receiving financial aid reduce the risk of rearrest?&lt;/p&gt;

&lt;p&gt;Each prisoner has seven baseline covariates (financial aid, age, race, work experience, marital status, parole, prior convictions) plus 52 weekly employment indicators. Of the 432 prisoners, 114 (26%) were rearrested within the year; the remaining 318 were censored (not rearrested during the observation period).&lt;/p&gt;

&lt;p&gt;The Kaplan-Meier curve gives us a first look at the survival pattern:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fifhlvc8zeozys1uinqir.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fifhlvc8zeozys1uinqir.webp" alt="Kaplan-Meier survival curve showing roughly 74% of prisoners avoid rearrest over 52 weeks" width="800" height="489"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;About 74% of prisoners avoid rearrest through the full year. But the curve doesn't tell us which factors predict rearrest. That's where Cox regression comes in.&lt;/p&gt;

&lt;h2&gt;
  
  
  Quick Win: Cox Regression in Action
&lt;/h2&gt;

&lt;p&gt;Let's fit a Cox model and see which covariates matter. Click the badge to run this yourself:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/statistics/cox_proportional_hazards.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pandas&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;lifelines&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;CoxPHFitter&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;KaplanMeierFitter&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;lifelines.datasets&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;load_rossi&lt;/span&gt;

&lt;span class="n"&gt;rossi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;load_rossi&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;rossi&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; prisoners, &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;rossi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;arrest&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; rearrested&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;432 prisoners, 114 rearrested
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Fitting the Cox model takes a single &lt;code&gt;fit&lt;/code&gt; call:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;cph&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;CoxPHFitter&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;cph&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;rossi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;duration_col&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;week&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;event_col&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;arrest&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;cph&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;print_summary&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The forest plot shows the hazard ratios for each covariate:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fry2904vv03enxhttf0b7.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fry2904vv03enxhttf0b7.webp" alt="Forest plot of hazard ratios: age and financial aid are protective, prior convictions increase risk" width="800" height="492"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Three findings jump out:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Age&lt;/strong&gt; (HR = 0.94, p = 0.01): each additional year of age reduces the hazard of rearrest by 6%. Older prisoners are less likely to reoffend.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Prior convictions&lt;/strong&gt; (HR = 1.10, p &amp;lt; 0.005): each additional prior conviction increases the hazard by 10%. Criminal history is the most statistically significant risk factor in this model.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Financial aid&lt;/strong&gt; (HR = 0.68, p = 0.05): receiving financial aid reduces the hazard by 32%. This is the key finding for the original study, though it's borderline significant.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The model's concordance is 0.64, meaning that in 64% of comparable pairs of prisoners, the one the model scores as higher risk is rearrested first. A value of 0.5 would correspond to random ranking.&lt;/p&gt;
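&lt;p&gt;The concordance index can be computed by hand on a toy example. The sketch below is a simplified version under stated assumptions: it skips tied times entirely, it treats a higher score as higher risk, and the four subjects are made up for illustration. It is not the routine that &lt;code&gt;lifelines&lt;/code&gt; uses internally.&lt;/p&gt;

```python
import numpy as np

def concordance_index(times, events, risk_scores):
    """Share of comparable pairs in which the earlier failure has the higher risk score."""
    order = np.argsort(times)
    times = np.asarray(times, dtype=float)[order]
    events = np.asarray(events, dtype=bool)[order]
    risk = np.asarray(risk_scores, dtype=float)[order]

    concordant, comparable = 0.0, 0
    for i in range(len(times)):
        if not events[i]:
            continue                      # a censored subject cannot be the earlier failure
        for j in range(i + 1, len(times)):
            if times[j] == times[i]:
                continue                  # skip tied times to keep the sketch simple
            comparable += 1
            # Correct ranking scores 1, a tie in risk scores 0.5, wrong ranking 0.
            concordant += 0.5 * (np.sign(risk[i] - risk[j]) + 1)
    return concordant / comparable

# Hypothetical subjects: the third is censored at week 12.
c_index = concordance_index(
    times=[5, 8, 12, 20],
    events=[1, 1, 0, 1],
    risk_scores=[2.0, 0.5, 1.0, 0.1],
)
print(c_index)  # 0.8
```

&lt;p&gt;Five pairs are comparable here, and the scores rank four of them correctly. The pair formed by the censored subject and the last subject is excluded because we never observe which of the two fails first.&lt;/p&gt;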

&lt;p&gt;Now let's see the effect of financial aid on the survival curve, holding all other covariates at their sample means:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fshgwiiune6hp3i1rf327.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fshgwiiune6hp3i1rf327.webp" alt="Cox-adjusted survival curves showing financial aid recipients have higher survival probability" width="800" height="489"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Prisoners who received financial aid (blue) have a visibly higher survival probability throughout the year. By week 52, the gap is roughly 7 percentage points: 80% vs 73% avoiding rearrest.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Cox Model in One Equation
&lt;/h3&gt;

&lt;p&gt;The Cox proportional hazards model expresses the hazard (instantaneous risk of the event) for individual &lt;code&gt;$i$&lt;/code&gt; at time &lt;code&gt;$t$&lt;/code&gt; as:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dh_i%28t%29%2520%253D%2520h_0%28t%29%2520%255Ccdot%2520%255Cexp%28%255Cbeta_1%2520x_%257Bi1%257D%2520%252B%2520%255Cbeta_2%2520x_%257Bi2%257D%2520%252B%2520%255Ccdots%2520%252B%2520%255Cbeta_p%2520x_%257Bip%257D%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dh_i%28t%29%2520%253D%2520h_0%28t%29%2520%255Ccdot%2520%255Cexp%28%255Cbeta_1%2520x_%257Bi1%257D%2520%252B%2520%255Cbeta_2%2520x_%257Bi2%257D%2520%252B%2520%255Ccdots%2520%252B%2520%255Cbeta_p%2520x_%257Bip%257D%29" alt="equation" width="485" height="27"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$h_0(t)$&lt;/code&gt; is the &lt;strong&gt;baseline hazard&lt;/strong&gt; (shared by everyone) and the exponential term scales it up or down based on covariates. The key insight: we never need to estimate &lt;code&gt;$h_0(t)$&lt;/code&gt;. Cox's partial likelihood eliminates it entirely, letting us estimate the &lt;code&gt;$\beta$&lt;/code&gt; coefficients from the data alone.&lt;/p&gt;

&lt;h3&gt;
  
  
  Hazard Ratios: The Language of Cox Models
&lt;/h3&gt;

&lt;p&gt;The quantity &lt;code&gt;$\exp(\beta)$&lt;/code&gt; is the &lt;strong&gt;hazard ratio&lt;/strong&gt; (HR). For a binary covariate like financial aid:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;HR &amp;lt; 1 means the covariate is protective (lower hazard)&lt;/li&gt;
&lt;li&gt;HR &amp;gt; 1 means the covariate increases risk&lt;/li&gt;
&lt;li&gt;HR = 1 means no effect&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;For our financial aid variable: HR = 0.68 means that, holding everything else constant, prisoners who received financial aid have 68% of the hazard of those who didn't. Equivalently, financial aid reduces the hazard by 32%.&lt;/p&gt;

&lt;p&gt;For a continuous covariate like age: HR = 0.94 means each additional year of age multiplies the hazard by 0.94. A 30-year-old has &lt;code&gt;$0.94^{10} = 0.54$&lt;/code&gt; times the hazard of a 20-year-old (a 46% lower hazard), holding everything else constant.&lt;/p&gt;
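&lt;p&gt;The arithmetic is easy to check directly. The snippet below uses the rounded hazard ratios quoted in this post, plus one standard consequence of proportional hazards that is not derived here: two survival curves that differ only in one covariate are related by raising the reference curve to the power of the hazard ratio.&lt;/p&gt;

```python
hr_age = 0.94      # hazard ratio per additional year of age
hr_prio = 1.10     # hazard ratio per additional prior conviction
hr_fin = 0.68      # hazard ratio for receiving financial aid

# Continuous covariates compound multiplicatively: a gap of d units scales the hazard by HR ** d.
ten_years_older = hr_age ** 10
three_more_priors = hr_prio ** 3

# Under proportional hazards, S_aid(t) = S_no_aid(t) ** HR at every time t.
survival_no_aid = 0.73                     # week-52 value quoted for the no-aid group
survival_aid = survival_no_aid ** hr_fin

print(round(ten_years_older, 2))     # 0.54
print(round(three_more_priors, 2))   # 1.33
print(round(survival_aid, 2))        # 0.81
```

&lt;p&gt;The last line lands close to the roughly 80% read off the adjusted survival curves for the financial aid group; the small difference comes from working with rounded inputs.&lt;/p&gt;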

&lt;h3&gt;
  
  
  The Partial Likelihood Trick
&lt;/h3&gt;

&lt;p&gt;The magic of the Cox model is the &lt;strong&gt;partial likelihood&lt;/strong&gt;, introduced by &lt;a href="https://doi.org/10.1111/j.2517-6161.1972.tb00899.x" rel="noopener noreferrer"&gt;Cox (1972)&lt;/a&gt;. At each event time, we ask: "Given that someone in the risk set was about to fail, what's the probability it was this particular individual?" That probability depends only on the &lt;code&gt;$\beta$&lt;/code&gt; coefficients, not on &lt;code&gt;$h_0(t)$&lt;/code&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DL%28%255Cbeta%29%2520%253D%2520%255Cprod_%257Bj%253A%2520%255Ctext%257Bevents%257D%257D%2520%255Cfrac%257B%255Cexp%28X_j%2520%255Cbeta%29%257D%257B%255Csum_%257Bk%2520%255Cin%2520R_j%257D%2520%255Cexp%28X_k%2520%255Cbeta%29%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DL%28%255Cbeta%29%2520%253D%2520%255Cprod_%257Bj%253A%2520%255Ctext%257Bevents%257D%257D%2520%255Cfrac%257B%255Cexp%28X_j%2520%255Cbeta%29%257D%257B%255Csum_%257Bk%2520%255Cin%2520R_j%257D%2520%255Cexp%28X_k%2520%255Cbeta%29%257D" alt="equation" width="327" height="68"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$R_j$&lt;/code&gt; is the set of individuals still at risk just before time &lt;code&gt;$t_j$&lt;/code&gt;. The baseline hazard cancels out in the ratio. This is what makes the Cox model semi-parametric: parametric in the covariate effects, non-parametric in the baseline hazard.&lt;/p&gt;
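&lt;p&gt;A minimal NumPy sketch of the log of this partial likelihood is shown below. It assumes no tied event times, and the four-subject data set is hypothetical. Note that nothing resembling &lt;code&gt;$h_0(t)$&lt;/code&gt; appears anywhere in the function.&lt;/p&gt;

```python
import numpy as np

def log_partial_likelihood(beta, X, times, events):
    """Cox log partial likelihood, assuming no tied event times."""
    order = np.argsort(times)                       # earliest time first
    eta = np.asarray(X, dtype=float)[order] @ np.asarray(beta, dtype=float)
    died = np.asarray(events, dtype=float)[order]
    # After sorting, the risk set of subject j is everyone from position j onwards,
    # so a reversed cumulative sum gives sum_{k in R_j} exp(X_k beta) for every j.
    risk_set_sum = np.cumsum(np.exp(eta)[::-1])[::-1]
    return float(np.sum(died * (eta - np.log(risk_set_sum))))

# Toy data (hypothetical): one binary covariate, four subjects, one censored at t = 4.
X = np.array([[1.0], [0.0], [1.0], [0.0]])
times = np.array([2.0, 4.0, 6.0, 8.0])
events = np.array([1, 0, 1, 1])

for b in (0.0, 0.5, 1.0):
    print(b, round(log_partial_likelihood([b], X, times, events), 4))
```

&lt;p&gt;At &lt;code&gt;$\beta = 0$&lt;/code&gt; every subject in a risk set is equally likely to fail, so the value reduces to minus the sum of the log risk-set sizes at the three events: -(log 4 + log 2 + log 1) = -log 8. Fitting the model means searching for the &lt;code&gt;$\beta$&lt;/code&gt; that maximises this function.&lt;/p&gt;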

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fvbeqfa0uhle6hr3m1fl3.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fvbeqfa0uhle6hr3m1fl3.webp" alt="Flow diagram showing how the partial likelihood works: at each event time, the risk set shrinks and a probability ratio is computed, then all ratios are multiplied together" width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;In &lt;code&gt;lifelines&lt;/code&gt;, this is all handled internally:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;cph&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;CoxPHFitter&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;cph&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;rossi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;duration_col&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;week&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;event_col&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;arrest&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Reading the Output
&lt;/h3&gt;

&lt;p&gt;The &lt;code&gt;print_summary()&lt;/code&gt; output gives you:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Column&lt;/th&gt;
&lt;th&gt;Meaning&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;coef&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;The &lt;code&gt;$\beta$&lt;/code&gt; coefficient (log hazard ratio)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;exp(coef)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;The hazard ratio&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;se(coef)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Standard error of &lt;code&gt;$\beta$&lt;/code&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;z&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Wald test statistic (&lt;code&gt;$\beta / \text{SE}$&lt;/code&gt;)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;p&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;p-value for &lt;code&gt;$H_0: \beta = 0$&lt;/code&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;exp(coef) lower/upper 95%&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;95% CI for the hazard ratio&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;A hazard ratio whose 95% CI includes 1.0 is not statistically significant at the 5% level. In the forest plot, covariates in grey (race, work experience, marital status, parole) all span the HR = 1 line.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Time-Dependent Covariates: When Risk Changes Over Time
&lt;/h3&gt;

&lt;p&gt;The basic Cox model assumes each covariate is fixed at baseline. But what about employment, which changes week to week? The original R code addresses this by expanding the data into &lt;strong&gt;start-stop format&lt;/strong&gt;: one row per person per week, with that week's employment status.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;lifelines&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;CoxTimeVaryingFitter&lt;/span&gt;

&lt;span class="c1"&gt;# Expand to one row per person-week with employment status
# (see notebook for full expansion code)
&lt;/span&gt;&lt;span class="n"&gt;ctv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;CoxTimeVaryingFitter&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;ctv&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;df_expanded&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;id_col&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;id&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;event_col&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;arrest&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;start_col&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;start&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;stop_col&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;stop&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
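&lt;p&gt;For readers without the notebook to hand, the expansion can be sketched in plain Python. This is an illustrative version, not the notebook's code: the helper name and the example subject are hypothetical, and the column names are chosen to match the &lt;code&gt;fit&lt;/code&gt; call above.&lt;/p&gt;

```python
def expand_to_start_stop(person_id, week, arrest, weekly_employment):
    """One row per observed week: a (start, stop] interval plus that week's employment flag."""
    rows = []
    for stop in range(1, week + 1):
        rows.append({
            "id": person_id,
            "start": stop - 1,
            "stop": stop,
            "employed": weekly_employment[stop - 1],
            # The event flag is 1 only on the final interval of someone who was rearrested.
            "arrest": int(arrest == 1 and stop == week),
        })
    return rows

# Hypothetical subject: rearrested in week 4, employed in weeks 2 and 3 only.
rows = expand_to_start_stop(person_id=1, week=4, arrest=1, weekly_employment=[0, 1, 1, 0])
for row in rows:
    print(row)
```

&lt;p&gt;Concatenating these rows for every subject and wrapping them in a DataFrame gives a table in the start-stop layout. A censored subject simply has &lt;code&gt;arrest = 0&lt;/code&gt; on every row.&lt;/p&gt;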



&lt;p&gt;Employment turns out to be the strongest predictor: HR = 0.35, meaning employed weeks have only 35% of the rearrest hazard compared to unemployed weeks. But the causality is ambiguous: prisoners who avoid rearrest are also more likely to maintain employment. Being in jail prevents you from showing up to work.&lt;/p&gt;

&lt;p&gt;The risk trajectories for three participants illustrate this:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fy7yt17pmtzr8mnym2ln1.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fy7yt17pmtzr8mnym2ln1.webp" alt="Risk trajectories showing how employment status changes predicted risk over time" width="800" height="489"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The never-employed participant (blue, arrested) maintains consistently high risk. The intermittently employed participant (red) shows risk that jumps up and down as employment changes. The mostly-employed participant (green) has consistently low risk.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Proportional Hazards Assumption
&lt;/h3&gt;

&lt;p&gt;The Cox model assumes that the hazard ratio between any two individuals stays &lt;strong&gt;constant over time&lt;/strong&gt;. This is the "proportional" in proportional hazards. If financial aid halves your hazard at week 1, it should also halve it at week 50.&lt;/p&gt;

&lt;p&gt;We test this with Schoenfeld residuals. For each event, the Schoenfeld residual measures how much the covariate value at that event differs from what the model expected. If the residuals trend with time, the proportional hazards assumption is violated.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fpy3sko0uubc3mpmywvzi.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fpy3sko0uubc3mpmywvzi.webp" alt="Schoenfeld residual plots for financial aid, age, and prior convictions with LOWESS smoothers" width="800" height="237"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;For financial aid (left), the LOWESS smoother is roughly flat: the PH assumption holds well (p = 0.98). For age (centre), there's a subtle downward trend, suggesting that the effect of age changes over time; with p = 0.01 the test rejects proportionality at the 5% level, even though the trend looks mild. For prior convictions (right), the smoother is close to flat (p = 0.38).&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;lifelines.statistics&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;proportional_hazard_test&lt;/span&gt;
&lt;span class="n"&gt;ph_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;proportional_hazard_test&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cph_reduced&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;df_reduced&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;time_transform&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;rank&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
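&lt;p&gt;To see what a Schoenfeld residual actually is, here is a small hand-rolled sketch for a single covariate. It assumes no tied event times, uses made-up data, and computes the raw residuals; as far as we are aware, the &lt;code&gt;lifelines&lt;/code&gt; test works with a scaled version of them.&lt;/p&gt;

```python
import numpy as np

def schoenfeld_residuals(beta, x, times, events):
    """Raw Schoenfeld residuals for one covariate, assuming no tied event times."""
    order = np.argsort(times)
    x = np.asarray(x, dtype=float)[order]
    died = np.asarray(events, dtype=bool)[order]
    weights = np.exp(beta * x)
    # Risk-set weighted mean of x at each time, via reversed cumulative sums.
    expected_x = np.cumsum((weights * x)[::-1])[::-1] / np.cumsum(weights[::-1])[::-1]
    # Residual = covariate of the subject who failed minus the model's expectation.
    return (x - expected_x)[died]

# Hypothetical data: binary covariate, second subject censored.
x = [1.0, 0.0, 1.0, 0.0]
times = [2.0, 4.0, 6.0, 8.0]
events = [1, 0, 1, 1]

residuals = schoenfeld_residuals(0.0, x, times, events)
print(residuals)
```

&lt;p&gt;There is one residual per event, ordered by event time. Plotting them against time and checking for a trend is the idea behind the figure above.&lt;/p&gt;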



&lt;p&gt;When the PH assumption fails, options include: stratifying by the offending covariate, adding a time-interaction term, or switching to an accelerated failure time model (as in &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;our Bayesian survival post&lt;/a&gt;).&lt;/p&gt;

&lt;h3&gt;
  
  
  The Baseline Hazard
&lt;/h3&gt;

&lt;p&gt;Although the Cox model doesn't need the baseline hazard to estimate coefficients, we can recover it after fitting. The Breslow estimator gives a non-parametric estimate of the cumulative baseline hazard:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fcfsoa7u2voj1c877hl0z.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fcfsoa7u2voj1c877hl0z.webp" alt="Cumulative baseline hazard increasing roughly linearly over 52 weeks" width="800" height="536"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The roughly linear increase suggests an approximately constant baseline hazard rate: for someone with average covariates, the risk of rearrest per unit time doesn't change much over the year. A constant hazard is what an exponential model assumes (equivalently, a Weibull with shape parameter close to 1), so this finding supports the kind of parametric model we used in &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;our Bayesian survival analysis post&lt;/a&gt;.&lt;/p&gt;
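&lt;p&gt;The Breslow estimator mentioned above is short enough to sketch by hand. The version below assumes no tied event times and uses hypothetical data; its baseline refers to a subject whose covariates are all zero, whereas a fitted library model may centre the covariates first.&lt;/p&gt;

```python
import numpy as np

def breslow_cumulative_hazard(beta, X, times, events):
    """Breslow estimate of the cumulative baseline hazard at each event time (no ties)."""
    order = np.argsort(times)
    t = np.asarray(times, dtype=float)[order]
    died = np.asarray(events, dtype=bool)[order]
    risk = np.exp(np.asarray(X, dtype=float)[order] @ np.asarray(beta, dtype=float))
    risk_set_sum = np.cumsum(risk[::-1])[::-1]
    # Each event adds 1 / (sum of exp(X_k beta) over the risk set) to the cumulative hazard.
    increments = died / risk_set_sum
    return t[died], np.cumsum(increments)[died]

# Hypothetical data: one binary covariate, second subject censored.
X = np.array([[1.0], [0.0], [1.0], [0.0]])
times = np.array([2.0, 4.0, 6.0, 8.0])
events = np.array([1, 0, 1, 1])

event_times, cumulative_hazard = breslow_cumulative_hazard([0.0], X, times, events)
print(event_times)
print(cumulative_hazard)
```

&lt;p&gt;With &lt;code&gt;$\beta = 0$&lt;/code&gt; the increments are 1/4, 1/2 and 1, one over the size of each risk set. In a real fit, the estimated &lt;code&gt;$\beta$&lt;/code&gt; goes in the first argument.&lt;/p&gt;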

&lt;h3&gt;
  
  
  Cox vs Bayesian AFT: When to Use Which
&lt;/h3&gt;

&lt;p&gt;Our &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;Bayesian survival analysis post&lt;/a&gt; used an accelerated failure time (AFT) model with PyMC. How does it compare to Cox?&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Feature&lt;/th&gt;
&lt;th&gt;Cox PH&lt;/th&gt;
&lt;th&gt;Bayesian AFT&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Approach&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Semi-parametric (no baseline assumption)&lt;/td&gt;
&lt;td&gt;Parametric (Weibull, log-logistic, etc.)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Interpretation&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Hazard ratios: "how much higher is the event rate at any instant?"&lt;/td&gt;
&lt;td&gt;Time ratios: "how much longer until the event?"&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Uncertainty&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Frequentist CIs&lt;/td&gt;
&lt;td&gt;Full posterior distributions&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Flexibility&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Time-dependent covariates straightforward&lt;/td&gt;
&lt;td&gt;Hierarchical structure, custom priors&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Assumption&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Proportional hazards (testable)&lt;/td&gt;
&lt;td&gt;Distributional form of baseline hazard&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;Use Cox when you want an assumption-light, interpretable analysis that's the standard in your field (medicine, criminal justice). Use Bayesian AFT when you need full posterior distributions rather than confidence intervals, when you need hierarchical structure, or when the PH assumption fails.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F81agpg61h07h81e5j3nt.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F81agpg61h07h81e5j3nt.webp" alt="Comparison diagram showing Cox PH (semi-parametric, hazard ratios, frequentist CIs) vs Bayesian AFT (fully parametric, time ratios, posterior distributions)" width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  Hyperparameter Choices
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Parameter&lt;/th&gt;
&lt;th&gt;Value&lt;/th&gt;
&lt;th&gt;Why&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Full model covariates&lt;/td&gt;
&lt;td&gt;fin, age, race, wexp, mar, paro, prio&lt;/td&gt;
&lt;td&gt;All baseline covariates from the original study&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Reduced model&lt;/td&gt;
&lt;td&gt;fin, age, prio&lt;/td&gt;
&lt;td&gt;Only the statistically significant predictors&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Time-dependent expansion&lt;/td&gt;
&lt;td&gt;52 weekly employment indicators&lt;/td&gt;
&lt;td&gt;One row per person-week for start-stop format&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Confidence level&lt;/td&gt;
&lt;td&gt;95%&lt;/td&gt;
&lt;td&gt;Standard for medical/social science applications&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;PH test transform&lt;/td&gt;
&lt;td&gt;Rank&lt;/td&gt;
&lt;td&gt;More robust than identity transform for discrete event times&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Cox (1972): One of the Most-Cited Statistics Papers
&lt;/h3&gt;

&lt;p&gt;David Cox introduced the proportional hazards model in his &lt;a href="https://doi.org/10.1111/j.2517-6161.1972.tb00899.x" rel="noopener noreferrer"&gt;1972 paper&lt;/a&gt; "Regression Models and Life-Tables", published in the Journal of the Royal Statistical Society, Series B. With over 50,000 citations, it's one of the most influential statistics papers ever written.&lt;/p&gt;

&lt;p&gt;Cox's insight was that for many survival problems, we care about the &lt;strong&gt;relative&lt;/strong&gt; effect of covariates, not the absolute hazard function. By conditioning on the set of individuals at risk at each event time, the baseline hazard cancels out of the partial likelihood, so the coefficients can be estimated without ever specifying it.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"The model is formulated in a very general way and is not restricted to any particular parametric family of distributions."&lt;br&gt;
-- Cox (1972)&lt;/p&gt;
&lt;/blockquote&gt;
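&lt;p&gt;The cancellation of the baseline hazard can be written out explicitly. The notation below is assumed for illustration: &lt;code&gt;$h_0$&lt;/code&gt; is the baseline hazard, &lt;code&gt;$x_i$&lt;/code&gt; the covariate vector of subject &lt;code&gt;$i$&lt;/code&gt;, &lt;code&gt;$R(t_i)$&lt;/code&gt; the set of subjects still at risk just before time &lt;code&gt;$t_i$&lt;/code&gt;, and the product runs over subjects with an observed event. Tied event times are ignored in this sketch.&lt;/p&gt;

```latex
\[
L(\beta)
= \prod_{i:\,\text{event}}
\frac{h_0(t_i)\,\exp(x_i^\top \beta)}
     {\sum_{j \in R(t_i)} h_0(t_i)\,\exp(x_j^\top \beta)}
= \prod_{i:\,\text{event}}
\frac{\exp(x_i^\top \beta)}
     {\sum_{j \in R(t_i)} \exp(x_j^\top \beta)}
\]
```

&lt;p&gt;Because &lt;code&gt;$h_0(t_i)$&lt;/code&gt; multiplies the numerator and every term of the denominator, it divides out, leaving an expression that depends only on &lt;code&gt;$\beta$&lt;/code&gt;.&lt;/p&gt;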

&lt;p&gt;The partial likelihood was initially controversial. It's not a full likelihood in the classical sense, and the asymptotic theory required new developments. &lt;a href="https://doi.org/10.1214/aos/1176345976" rel="noopener noreferrer"&gt;Andersen and Gill (1982)&lt;/a&gt; provided the rigorous counting-process framework that established the mathematical foundations.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Rossi Dataset
&lt;/h3&gt;

&lt;p&gt;Our dataset comes from &lt;a href="https://doi.org/10.1016/B978-0-12-598240-9.X5001-7" rel="noopener noreferrer"&gt;Rossi, Berk, and Lenihan (1980)&lt;/a&gt;, &lt;em&gt;Money, Work, and Crime: Experimental Evidence&lt;/em&gt;. The study was a randomised experiment: prisoners were randomly assigned to receive financial aid (or not) upon release, then tracked for one year. Because of this experimental design, the financial aid coefficient is easier to interpret causally than it would be in a typical observational study, though compliance was imperfect.&lt;/p&gt;

&lt;p&gt;The R analysis we translated follows &lt;a href="https://cran.r-project.org/doc/contrib/Fox-Companion/appendix-cox-regression.pdf" rel="noopener noreferrer"&gt;John Fox's appendix on Cox regression&lt;/a&gt;, written to accompany &lt;em&gt;An R and S-PLUS Companion to Applied Regression&lt;/em&gt;, which has been a standard teaching resource for Cox models in R since 2002.&lt;/p&gt;

&lt;h3&gt;
  
  
  From Cox to Modern Survival Analysis
&lt;/h3&gt;

&lt;p&gt;Cox's framework has been extended in many directions:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Stratified Cox models&lt;/strong&gt;: different baseline hazards for different groups (e.g., men vs women), but shared covariate effects&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Frailty models&lt;/strong&gt;: random effects for unobserved heterogeneity, analogous to &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;hierarchical models&lt;/a&gt; in Bayesian statistics&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Time-varying coefficients&lt;/strong&gt;: let &lt;code&gt;$\beta$&lt;/code&gt; change over time, relaxing the PH assumption&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Competing risks&lt;/strong&gt;: multiple possible event types (rearrest vs death vs emigration)&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;For Python practitioners, &lt;a href="https://lifelines.readthedocs.io/" rel="noopener noreferrer"&gt;lifelines&lt;/a&gt; provides a mature, well-documented implementation. For Bayesian alternatives, PyMC can fit both AFT and piecewise-exponential Cox models (see &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;our Post 21&lt;/a&gt;).&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;The original paper:&lt;/strong&gt; &lt;a href="https://doi.org/10.1111/j.2517-6161.1972.tb00899.x" rel="noopener noreferrer"&gt;Cox (1972)&lt;/a&gt;, "Regression Models and Life-Tables", JRSS Series B&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Mathematical foundations:&lt;/strong&gt; &lt;a href="https://doi.org/10.1214/aos/1176345976" rel="noopener noreferrer"&gt;Andersen &amp;amp; Gill (1982)&lt;/a&gt;, "Cox's Regression Model for Counting Processes"&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The R companion chapter:&lt;/strong&gt; &lt;a href="https://cran.r-project.org/doc/contrib/Fox-Companion/appendix-cox-regression.pdf" rel="noopener noreferrer"&gt;Fox (2002)&lt;/a&gt;, "Cox Proportional-Hazards Regression for Survival Data"&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Comprehensive textbook:&lt;/strong&gt; Kalbfleisch &amp;amp; Prentice (2002), &lt;em&gt;The Statistical Analysis of Failure Time Data&lt;/em&gt;, Wiley&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Python implementation:&lt;/strong&gt; &lt;a href="https://lifelines.readthedocs.io/" rel="noopener noreferrer"&gt;Davidson-Pilon (2019)&lt;/a&gt;, lifelines documentation&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Interactive Tools
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/kaplan-meier-calculator" rel="noopener noreferrer"&gt;Kaplan-Meier Calculator&lt;/a&gt; — Estimate and compare survival curves interactively&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/medical-stats-calculator" rel="noopener noreferrer"&gt;Medical Statistics Calculator&lt;/a&gt; — Compute diagnostic accuracy metrics for clinical data&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Related Posts
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;Bayesian Survival Analysis with PyMC: Modelling Customer Churn&lt;/a&gt; (Bayesian AFT alternative to Cox)&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;Hierarchical Bayesian Regression with PyMC&lt;/a&gt; (hierarchical modelling, related to frailty models)&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;Maximum Likelihood Estimation from Scratch&lt;/a&gt; (the MLE foundation that partial likelihood extends)&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/linear-regression-five-ways" rel="noopener noreferrer"&gt;Linear Regression Five Ways&lt;/a&gt; (regression fundamentals that Cox builds on)&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What does "semi-parametric" mean in the context of the Cox model?
&lt;/h3&gt;

&lt;p&gt;The Cox model is parametric in how covariates affect the hazard (through the exponential term with beta coefficients) but non-parametric in the baseline hazard, which is left completely unspecified. This means you get interpretable covariate effects without having to assume the hazard follows a Weibull, exponential, or any other distribution.&lt;/p&gt;

&lt;h3&gt;
  
  
  How do I interpret a hazard ratio less than 1?
&lt;/h3&gt;

&lt;p&gt;A hazard ratio below 1 means the covariate is protective. For example, a hazard ratio of 0.68 for financial aid means that prisoners receiving aid have 68% of the hazard of otherwise comparable prisoners who do not, or equivalently, a 32% lower instantaneous risk of rearrest at any given time point.&lt;/p&gt;
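&lt;p&gt;The arithmetic behind that reading is short enough to sketch. The helper name &lt;code&gt;describe_hazard_ratio&lt;/code&gt; and the coefficient value are hypothetical, chosen only so that the hazard ratio comes out at the 0.68 used above.&lt;/p&gt;

```python
import math


def describe_hazard_ratio(coef):
    """Convert a Cox log-hazard coefficient into a hazard ratio and percent change."""
    hazard_ratio = math.exp(coef)
    percent_change = (hazard_ratio - 1.0) * 100.0  # negative means a reduction
    return hazard_ratio, percent_change


# Hypothetical coefficient, picked so that exp(coef) equals 0.68
coef_fin = math.log(0.68)
hr, pct = describe_hazard_ratio(coef_fin)
print(f"coef = {coef_fin:.3f}, HR = {hr:.2f}, change in hazard = {pct:.0f}%")
```

&lt;p&gt;A negative coefficient always maps to a hazard ratio below 1, and a positive coefficient to a ratio above 1.&lt;/p&gt;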

&lt;h3&gt;
  
  
  What happens if the proportional hazards assumption is violated?
&lt;/h3&gt;

&lt;p&gt;If the hazard ratio between groups changes over time, the Cox model's estimates become averaged effects that may not accurately represent the relationship at any specific time point. Options include stratifying by the offending covariate, adding a time-interaction term, or switching to an accelerated failure time model that does not require proportional hazards.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can the Cox model handle categorical covariates with more than two levels?
&lt;/h3&gt;

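&lt;p&gt;Yes. Categorical covariates with k levels are represented using k-1 dummy variables, just as in standard linear regression. Each dummy variable's hazard ratio is interpreted relative to the reference category. In lifelines, the covariate columns must be numeric, so either dummy-code them yourself (for example with &lt;code&gt;pd.get_dummies(..., drop_first=True)&lt;/code&gt;) or pass a &lt;code&gt;formula&lt;/code&gt; to &lt;code&gt;fit&lt;/code&gt;, which encodes categorical terms for you.&lt;/p&gt;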
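&lt;p&gt;One way to build the k-1 dummy columns explicitly is sketched below, assuming pandas is available. The data frame and the three-level &lt;code&gt;offence&lt;/code&gt; covariate are invented for illustration and are not part of the Rossi dataset.&lt;/p&gt;

```python
import pandas as pd

# Hypothetical data: `offence` is an invented three-level covariate
df = pd.DataFrame({
    "week": [20, 17, 25, 52, 52, 30],
    "arrest": [1, 1, 1, 0, 0, 1],
    "offence": ["property", "violent", "drug", "property", "drug", "violent"],
})

# k = 3 levels become k - 1 = 2 indicator columns; the dropped level
# ("drug", first alphabetically) serves as the reference category
encoded = pd.get_dummies(df, columns=["offence"], drop_first=True, dtype=float)
print(encoded.columns.tolist())
```

&lt;p&gt;Each remaining indicator's hazard ratio is then read relative to the dropped level, here &lt;code&gt;drug&lt;/code&gt;.&lt;/p&gt;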

&lt;h3&gt;
  
  
  What is the concordance index and what counts as a good value?
&lt;/h3&gt;

&lt;p&gt;The concordance index (C-index) measures how well the model ranks individuals by risk. A value of 0.5 means the model is no better than random, while 1.0 means perfect ranking. In medical and social science applications, values between 0.6 and 0.7 are common and considered acceptable, while values above 0.8 are considered excellent.&lt;/p&gt;
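&lt;p&gt;To make "ranks individuals by risk" concrete, here is a minimal sketch of Harrell's C on toy data. It is a simplified illustration, not the lifelines implementation: pairs with tied times are skipped, and the function name and data are hypothetical.&lt;/p&gt;

```python
from itertools import combinations


def concordance_index(times, events, risk_scores):
    """Simplified Harrell's C: fraction of comparable pairs ranked correctly.

    A pair is comparable when the subject with the shorter time had the event.
    Pairs with tied times are skipped to keep the sketch short.
    """
    concordant = 0.0
    comparable = 0
    for i, j in combinations(range(len(times)), 2):
        if times[i] == times[j]:
            continue
        # Order the pair so `early` is the subject with the shorter time
        early, late = sorted((i, j), key=lambda k: times[k])
        if not events[early]:
            continue  # shorter time is censored, so the true ordering is unknown
        comparable += 1
        if risk_scores[early] > risk_scores[late]:
            concordant += 1.0  # higher predicted risk failed first
        elif risk_scores[early] == risk_scores[late]:
            concordant += 0.5  # tied predictions count as half
    return concordant / comparable


# Hypothetical toy data: weeks to rearrest, event flags, model risk scores
times = [5, 12, 20, 30, 52]
events = [1, 1, 0, 1, 0]
risk = [2.1, 0.4, 1.5, 0.9, 0.2]
print(concordance_index(times, events, risk))
```

&lt;p&gt;Censored subjects still contribute: they form comparable pairs with anyone who had an event before their censoring time.&lt;/p&gt;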

&lt;h3&gt;
  
  
  Why use the Cox model instead of logistic regression for survival data?
&lt;/h3&gt;

&lt;p&gt;Logistic regression ignores the time dimension entirely and cannot handle censored observations properly. If a prisoner was not rearrested during the study but might be rearrested later, logistic regression would treat them as a definitive non-event, wasting information. The Cox model uses the partial follow-up time from censored subjects to improve estimation.&lt;/p&gt;

</description>
      <category>statistics</category>
      <category>supervisedlearning</category>
    </item>
    <item>
      <title>MCMC for Mixture Models: Inferring Earthquake Regimes</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Sat, 16 May 2026 07:55:16 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/mcmc-for-mixture-models-inferring-earthquake-regimes-2c24</link>
      <guid>https://dev.to/berkan_sesen/mcmc-for-mixture-models-inferring-earthquake-regimes-2c24</guid>
      <description>&lt;p&gt;Between 1900 and 2006, the number of major earthquakes per year ranged from 6 to 41. In some decades the planet averaged fewer than 15; in others, closer to 30. That is far too much variation for a single Poisson process with a fixed rate. Something changed, and it changed more than once. The histogram hints at the story: not one clean bell curve, but a skewed shape with a second hump, as if two different Poisson processes were taking turns generating the data.&lt;/p&gt;

&lt;p&gt;The natural question is: are there hidden regimes? Maybe the Earth goes through periods of higher and lower seismic activity, and we're seeing a mixture of two Poisson processes with different rates. We explored mixture models before with &lt;a href="https://sesen.ai/blog/gaussian-mixture-models-em-in-practice" rel="noopener noreferrer"&gt;Gaussian mixtures and EM&lt;/a&gt;, but EM only gives you point estimates. Here, we want the full Bayesian posterior: not just "the rates are probably 16 and 27" but "how uncertain are we about those rates?"&lt;/p&gt;

&lt;p&gt;By the end of this post, you'll build a Metropolis-Hastings sampler from scratch that infers the hidden parameters of a two-component Poisson mixture, giving you posterior distributions for each parameter, complete with uncertainty quantification.&lt;/p&gt;

&lt;h2&gt;
  
  
  The Data: 107 Years of Earthquakes
&lt;/h2&gt;

&lt;p&gt;The dataset records the number of major earthquakes (magnitude 7 or greater) worldwide for each year from 1900 to 2006. It comes from Zucchini, MacDonald and Langrock's &lt;em&gt;Hidden Markov Models for Time Series&lt;/em&gt; (2016), page 10.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fwe8nlbed4xfap0z6bjka.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fwe8nlbed4xfap0z6bjka.webp" alt="Histogram of major earthquake counts per year, showing a right-skewed distribution with a hint of bimodality" width="800" height="444"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The mean is about 19 earthquakes per year, but the distribution stretches from 6 to 41. That's far more spread than a single Poisson would predict: a Poisson's variance equals its mean, so &lt;code&gt;$\lambda = 19$&lt;/code&gt; implies a variance of 19 (standard deviation &lt;code&gt;$\sqrt{19} \approx 4.4$&lt;/code&gt;), whereas the sample variance here is about 52. A two-component mixture is a natural hypothesis: one "quiet" regime and one "active" regime.&lt;/p&gt;
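&lt;p&gt;Overdispersion of this kind can be checked with the variance-to-mean ratio, which sits near 1 for Poisson data. The sketch below uses simulated stand-in data rather than the earthquake series, and the helper name &lt;code&gt;dispersion_index&lt;/code&gt; and the rates 15 and 25 are assumptions made for illustration.&lt;/p&gt;

```python
import numpy as np


def dispersion_index(counts):
    """Sample variance divided by sample mean; close to 1 for Poisson data."""
    counts = np.asarray(counts, dtype=float)
    return counts.var(ddof=1) / counts.mean()


# Simulated stand-in data (not the earthquake series): a single Poisson
# versus a 50/50 mixture of two Poissons with rates 15 and 25
rng = np.random.default_rng(0)
single = rng.poisson(20, size=5000)
mixture = np.where(rng.random(5000) > 0.5,
                   rng.poisson(15, size=5000),
                   rng.poisson(25, size=5000))

print(dispersion_index(single))   # near 1
print(dispersion_index(mixture))  # clearly above 1
```

&lt;p&gt;Applying the same function to the earthquake counts should give a ratio well above 1, which is what motivates the mixture.&lt;/p&gt;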

&lt;h2&gt;
  
  
  Quick Win: MCMC in Action
&lt;/h2&gt;

&lt;p&gt;Let's fit this mixture model with Metropolis-Hastings. Click the badge to run it yourself:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/bayesian/mcmc_poisson_mixture.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Watch the MCMC chain explore the posterior. It starts at our initial guess (&lt;code&gt;$\lambda_1 = 10, \lambda_2 = 20$&lt;/code&gt;) and gradually finds the high-probability region:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fhqpqg05alkmgd8aia0ym.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fhqpqg05alkmgd8aia0ym.gif" alt="MCMC chain exploring the posterior over lambda1 and lambda2, starting from initial values and settling into the high-density region" width="800" height="600"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Here's the complete implementation. We need three pieces: the data, a log-likelihood function, and the Metropolis-Hastings sampler.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;scipy.stats&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;poisson&lt;/span&gt;

&lt;span class="c1"&gt;# Major earthquakes per year (magnitude 7 or greater), 1900–2006
&lt;/span&gt;&lt;span class="n"&gt;eq&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="mi"&gt;13&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;14&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;26&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;18&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;36&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;24&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;22&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;23&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;22&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;18&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;25&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="mi"&gt;21&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;21&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;14&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;11&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;14&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;23&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;18&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;17&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;19&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;22&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;19&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;13&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;26&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;13&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;14&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="mi"&gt;22&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;24&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;21&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;22&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;26&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;21&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;23&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;24&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;41&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;31&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;35&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;26&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;28&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;36&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;39&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="mi"&gt;21&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;17&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;22&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;17&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;19&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;34&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;22&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;18&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;22&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;19&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="mi"&gt;30&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;29&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;23&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;21&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;21&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;25&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;18&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;18&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;14&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;11&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="mi"&gt;7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;18&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;13&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;13&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;18&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="mi"&gt;13&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;11&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;11&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The log-likelihood for a two-component Poisson mixture:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;poisson_mix_loglik&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lam1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lam2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;d1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Log-likelihood: sum_i log[ d1 * Pois(x_i|lam1) + d2 * Pois(x_i|lam2) ]&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;d2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;d1&lt;/span&gt;
    &lt;span class="n"&gt;mix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;d1&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;poisson&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;pmf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lam1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;d2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;poisson&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;pmf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lam2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;mix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;maximum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1e-300&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# protect against log(0)
&lt;/span&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mix&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
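&lt;p&gt;Clipping at &lt;code&gt;1e-300&lt;/code&gt; avoids &lt;code&gt;log(0)&lt;/code&gt;, but it also caps how unlikely an observation can look. A possible alternative, sketched here as a variant rather than the method used in this post, is to work in log space throughout. The names &lt;code&gt;poisson_logpmf&lt;/code&gt; and &lt;code&gt;poisson_mix_loglik_stable&lt;/code&gt; are hypothetical, and the sketch assumes positive rates and &lt;code&gt;d1&lt;/code&gt; strictly between 0 and 1.&lt;/p&gt;

```python
import math

import numpy as np


def poisson_logpmf(x, lam):
    """log Pois(x | lam) = x*log(lam) - lam - log(x!), evaluated elementwise."""
    x = np.asarray(x, dtype=float)
    log_factorial = np.array([math.lgamma(v + 1.0) for v in x])
    return x * np.log(lam) - lam - log_factorial


def poisson_mix_loglik_stable(data, lam1, lam2, d1):
    """Two-component Poisson mixture log-likelihood, computed in log space."""
    log_c1 = np.log(d1) + poisson_logpmf(data, lam1)
    log_c2 = np.log(1.0 - d1) + poisson_logpmf(data, lam2)
    # logaddexp(a, b) = log(exp(a) + exp(b)) without underflow
    return float(np.sum(np.logaddexp(log_c1, log_c2)))


# Hypothetical check: 400 is so far from both rates that each pmf underflows to 0
print(poisson_mix_loglik_stable([13, 14, 8, 400], 10.0, 20.0, 0.3))
```

&lt;p&gt;For counts in the range of the earthquake data the two versions should agree closely; the difference only shows up for extreme observations or extreme proposed rates.&lt;/p&gt;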



&lt;p&gt;Now the Metropolis-Hastings sampler. We propose new parameter values from a multivariate normal centred on the current position, then accept or reject based on the likelihood ratio. The proposal is symmetric, so no correction term is needed in the acceptance ratio:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;run_mcmc&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;burn_in&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# Initial values and proposal covariance (from the original R code)
&lt;/span&gt;    &lt;span class="n"&gt;sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;diag&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;  &lt;span class="c1"&gt;# proposal covariance
&lt;/span&gt;    &lt;span class="n"&gt;params&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;full&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;n_iter&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nan&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# lam1, lam2, d1, d2, loglik
&lt;/span&gt;
    &lt;span class="c1"&gt;# Starting point
&lt;/span&gt;    &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;10.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;20.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                 &lt;span class="nf"&gt;poisson_mix_loglik&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;10.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;20.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;

    &lt;span class="n"&gt;n_accept&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_iter&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;cur_ll&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

        &lt;span class="c1"&gt;# Propose: multivariate normal centred on current [lam1, lam2, d1]
&lt;/span&gt;        &lt;span class="n"&gt;current&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;prop&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;multivariate_normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;current&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cov&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;p_lam1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p_lam2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p_d1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;prop&lt;/span&gt;

        &lt;span class="c1"&gt;# Enforce constraints: lambdas &amp;gt; 0, 0 &amp;lt; delta1 &amp;lt; 1
&lt;/span&gt;        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;p_lam1&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;p_lam2&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;p_d1&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;p_d1&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;prop_ll&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;inf&lt;/span&gt;
        &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;prop_ll&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;poisson_mix_loglik&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p_lam1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p_lam2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p_d1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Accept/reject
&lt;/span&gt;        &lt;span class="n"&gt;log_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;prop_ll&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;cur_ll&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;uniform&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="nf"&gt;min&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;p_lam1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p_lam2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p_d1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;p_d1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;prop_ll&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
            &lt;span class="n"&gt;n_accept&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
        &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="n"&gt;posterior&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;burn_in&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_accept&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;n_iter&lt;/span&gt;

&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;all_params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;accept_rate&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;run_mcmc&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eq&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;burn_in&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Acceptance rate: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;accept_rate&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;lambda1: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; ± &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;lambda2: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; ± &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;delta1:  &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; ± &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Acceptance rate: 31.8%
lambda1: 15.8 ± 0.7
lambda2: 27.1 ± 1.6
delta1:  0.67 ± 0.08
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The sampler found two regimes: a "quiet" regime with &lt;code&gt;$\lambda_1 \approx 16$&lt;/code&gt; earthquakes per year (67% of years) and an "active" regime with &lt;code&gt;$\lambda_2 \approx 27$&lt;/code&gt; per year (33% of years). Here's how the fitted mixture matches the data:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjm5r87gfgwwppl2qk6we.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjm5r87gfgwwppl2qk6we.webp" alt="Fitted two-component Poisson mixture overlaid on the empirical histogram, with individual components shown as dashed lines" width="800" height="439"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The mixture captures the right-skew and the heavy tail that a single Poisson misses entirely.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Poisson Mixture Model
&lt;/h3&gt;

&lt;p&gt;We're modelling each year's earthquake count &lt;code&gt;$x_i$&lt;/code&gt; as drawn from one of two Poisson distributions:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28x_i%29%2520%253D%2520%255Cdelta_1%2520%255Ccdot%2520%255Ctext%257BPois%257D%28x_i%2520%255Cmid%2520%255Clambda_1%29%2520%252B%2520%255Cdelta_2%2520%255Ccdot%2520%255Ctext%257BPois%257D%28x_i%2520%255Cmid%2520%255Clambda_2%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28x_i%29%2520%253D%2520%255Cdelta_1%2520%255Ccdot%2520%255Ctext%257BPois%257D%28x_i%2520%255Cmid%2520%255Clambda_1%29%2520%252B%2520%255Cdelta_2%2520%255Ccdot%2520%255Ctext%257BPois%257D%28x_i%2520%255Cmid%2520%255Clambda_2%29" alt="equation" width="455" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$\delta_1 + \delta_2 = 1$&lt;/code&gt;. We don't know which regime generated each observation (that's the "hidden" part). We want to infer &lt;code&gt;$\lambda_1$&lt;/code&gt; (quiet-regime rate), &lt;code&gt;$\lambda_2$&lt;/code&gt; (active-regime rate), and &lt;code&gt;$\delta_1$&lt;/code&gt; (the probability a year belongs to the quiet regime). Since &lt;code&gt;$\delta_2 = 1 - \delta_1$&lt;/code&gt; is fixed once &lt;code&gt;$\delta_1$&lt;/code&gt; is known, that leaves three free parameters.&lt;/p&gt;

&lt;p&gt;The log-likelihood across all 107 years is:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;d2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;d1&lt;/span&gt;  &lt;span class="c1"&gt;# mixing weights sum to 1
&lt;/span&gt;&lt;span class="n"&gt;mix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;d1&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;poisson&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;pmf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lam1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;d2&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;poisson&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;pmf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lam2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loglik&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;maximum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mix&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1e-300&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



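&lt;p&gt;For each data point, we compute its probability under each component, weight by the mixing proportions, add the two terms, and take the log. The log-likelihood is the sum of those logs over all years. The &lt;code&gt;np.maximum(mix, 1e-300)&lt;/code&gt; guard in &lt;code&gt;poisson_mix_loglik&lt;/code&gt; prevents &lt;code&gt;log(0)&lt;/code&gt; when a proposed parameter makes a data point so unlikely that its probability underflows to zero.&lt;/p&gt;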
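&lt;p&gt;Clamping is one way to avoid &lt;code&gt;log(0)&lt;/code&gt;; another is to stay in log space throughout and combine the two components with the log-sum-exp trick. The block below is a minimal standard-library sketch of that alternative, not the function used in the sampler above. The names &lt;code&gt;poisson_logpmf&lt;/code&gt; and &lt;code&gt;poisson_mix_loglik_stable&lt;/code&gt; and the toy counts are assumptions made for illustration.&lt;/p&gt;

```python
import math

def poisson_logpmf(x, lam):
    """Log of the Poisson pmf: x*log(lam) - lam - log(x!)."""
    return x * math.log(lam) - lam - math.lgamma(x + 1)

def poisson_mix_loglik_stable(data, lam1, lam2, d1):
    """Two-component Poisson mixture log-likelihood, computed in log space."""
    total = 0.0
    for x in data:
        # Log of each weighted component
        a = math.log(d1) + poisson_logpmf(x, lam1)
        b = math.log(1.0 - d1) + poisson_logpmf(x, lam2)
        # log(exp(a) + exp(b)) without underflow: factor out the larger term
        top = max(a, b)
        total += top + math.log(math.exp(a - top) + math.exp(b - top))
    return total

# Illustrative counts, not the earthquake series from this post
toy_counts = [12, 18, 25, 31]
print(poisson_mix_loglik_stable(toy_counts, 15.8, 27.1, 0.67))
```

&lt;p&gt;Because the larger term is factored out before exponentiating, the result stays finite even for counts that would make both &lt;code&gt;pmf&lt;/code&gt; values underflow. It does assume the constraints have already been checked, since &lt;code&gt;math.log&lt;/code&gt; raises an error for non-positive inputs.&lt;/p&gt;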

&lt;h3&gt;
  
  
  Metropolis-Hastings: Propose, Evaluate, Accept/Reject
&lt;/h3&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fhhh4zbtsxufrlgqyz3vd.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fhhh4zbtsxufrlgqyz3vd.webp" alt="The Metropolis-Hastings algorithm flow: propose, check constraints, evaluate likelihood ratio, accept or reject" width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The sampler follows the same &lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;Metropolis-Hastings logic from our island-hopping post&lt;/a&gt;, but now in a continuous 3D parameter space instead of a discrete one.&lt;/p&gt;

&lt;p&gt;At each iteration:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Propose&lt;/strong&gt; a new point &lt;code&gt;$(\lambda_1', \lambda_2', \delta_1')$&lt;/code&gt; from a multivariate normal centred on the current position, with covariance &lt;code&gt;$\Sigma = \text{diag}(1, 1, 0.01)$&lt;/code&gt;.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Evaluate&lt;/strong&gt; the log-likelihood at the proposed point. If any constraint is violated (&lt;code&gt;$\lambda \leq 0$&lt;/code&gt; or &lt;code&gt;$\delta_1 \notin (0, 1)$&lt;/code&gt;), assign &lt;code&gt;$-\infty$&lt;/code&gt;.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Accept or reject&lt;/strong&gt; with probability &lt;code&gt;$\min(1, r)$&lt;/code&gt; where &lt;code&gt;$r = L(\theta') / L(\theta)$&lt;/code&gt; is the likelihood ratio. The Gaussian proposal is symmetric, so its density cancels and no Hastings correction is needed. In log space:
&lt;/li&gt;
&lt;/ol&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;log_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;prop_ll&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;cur_ll&lt;/span&gt;
&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;uniform&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="nf"&gt;min&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;log_ratio&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# accept
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is identical to the acceptance rule in our &lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;Bayesian inference post&lt;/a&gt;, but applied to a mixture likelihood rather than a single-distribution posterior.&lt;/p&gt;
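&lt;p&gt;To see the rule in isolation, it can be wrapped in a small helper. This is a sketch using only the standard library; &lt;code&gt;mh_accept&lt;/code&gt; is a hypothetical name, and drawing the uniform from (0, 1] rather than [0, 1) is a choice made here so that &lt;code&gt;log(u)&lt;/code&gt; is always finite.&lt;/p&gt;

```python
import math
import random

def mh_accept(prop_ll, cur_ll, rng=random):
    """Log-space Metropolis test: accept with probability min(1, exp(prop_ll - cur_ll))."""
    log_ratio = prop_ll - cur_ll
    u = 1.0 - rng.random()  # uniform on (0, 1], so log(u) is always finite
    return min(0.0, log_ratio) >= math.log(u)

print(mh_accept(-50.0, -100.0))      # better proposal: always accepted
print(mh_accept(-math.inf, -100.0))  # constraint violation: always rejected
```

&lt;p&gt;A proposal with a higher log-likelihood gives &lt;code&gt;min(0.0, log_ratio) = 0&lt;/code&gt;, which is never below &lt;code&gt;log(u)&lt;/code&gt;, so it is always accepted. A log-likelihood of &lt;code&gt;-inf&lt;/code&gt; can never reach a finite &lt;code&gt;log(u)&lt;/code&gt;, so it is always rejected.&lt;/p&gt;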

&lt;h3&gt;
  
  
  The Proposal Distribution
&lt;/h3&gt;

&lt;p&gt;The proposal covariance &lt;code&gt;$\Sigma = \text{diag}(1, 1, 0.01)$&lt;/code&gt; controls the step size. The &lt;code&gt;$\lambda$&lt;/code&gt; parameters get steps of standard deviation 1 (reasonable, since they live around 10-30), while &lt;code&gt;$\delta_1$&lt;/code&gt; gets steps of standard deviation 0.1 (since it's bounded between 0 and 1).&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;diag&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;prop&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;multivariate_normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mean&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;current&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cov&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;If the steps are too large, most proposals land in low-probability regions and get rejected (acceptance rate near 0%). If they are too small, the chain accepts nearly every step but barely moves, taking ages to explore the posterior. The 32% acceptance rate we got sits in the sweet spot: optimal-scaling results for random-walk Metropolis point to roughly 44% for a one-dimensional target and about 23% in high dimensions, so 20-50% is a sensible rule of thumb for a three-parameter problem.&lt;/p&gt;
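&lt;p&gt;The step-size trade-off is easy to reproduce on a toy problem. The sketch below runs random-walk Metropolis on a standard normal target, which is an assumption chosen for simplicity and not the earthquake model, and reports the acceptance rate for three proposal scales. &lt;code&gt;acceptance_rate&lt;/code&gt; is a hypothetical helper.&lt;/p&gt;

```python
import math
import random

def acceptance_rate(step_sd, n_iter=5000, seed=0):
    """Random-walk Metropolis on a standard normal target; returns the fraction accepted."""
    rng = random.Random(seed)
    x = 0.0
    n_accept = 0
    for _ in range(n_iter):
        prop = x + rng.gauss(0.0, step_sd)
        # Log density ratio of N(0, 1) at the proposal versus the current point
        log_ratio = 0.5 * (x * x - prop * prop)
        if min(0.0, log_ratio) >= math.log(1.0 - rng.random()):
            x = prop
            n_accept += 1
    return n_accept / n_iter

for step_sd in (0.1, 2.4, 50.0):
    print(f"step sd {step_sd}: acceptance {acceptance_rate(step_sd):.1%}")
```

&lt;p&gt;Tiny steps are accepted almost every time, huge steps almost never, and the middle setting lands in between. A high acceptance rate on its own says nothing about how quickly the chain explores the target.&lt;/p&gt;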

&lt;h3&gt;
  
  
  Burn-In: Forgetting the Starting Point
&lt;/h3&gt;

&lt;p&gt;We initialised at &lt;code&gt;$\lambda_1 = 10, \lambda_2 = 20, \delta_1 = 0.3$&lt;/code&gt;, which is deliberately far from the posterior mode. The early iterations are the chain migrating from this poor starting point to the high-probability region, so we discard the first 100 as &lt;strong&gt;burn-in&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;The trace plots make this visible:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fvacgy7wuzk05698ntzpa.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fvacgy7wuzk05698ntzpa.webp" alt="MCMC trace plots for lambda1, lambda2, and delta1 showing burn-in migration and stable mixing" width="800" height="551"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;In the grey-shaded burn-in region, you can see all three parameters climbing from their initial values toward the posterior. After burn-in, the chains oscillate around their posterior means: &lt;code&gt;$\lambda_1 \approx 15.8$&lt;/code&gt;, &lt;code&gt;$\lambda_2 \approx 27.1$&lt;/code&gt;, &lt;code&gt;$\delta_1 \approx 0.67$&lt;/code&gt;.&lt;/p&gt;

&lt;h3&gt;
  
  
  Reading the Posterior
&lt;/h3&gt;

&lt;p&gt;Unlike &lt;a href="https://sesen.ai/blog/gaussian-mixture-models-em-in-practice" rel="noopener noreferrer"&gt;EM for Gaussian mixtures&lt;/a&gt;, which gives you point estimates, MCMC gives you the full posterior distribution. The marginal posteriors show the uncertainty in each parameter:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fxx764ri31c68whw9ces1.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fxx764ri31c68whw9ces1.webp" alt="Marginal posterior distributions for lambda1, lambda2, and delta1 with KDE overlays" width="800" height="256"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The 95% credible intervals are:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;$\lambda_1 \in (14.3, 17.1)$&lt;/code&gt;: the quiet regime averages roughly 14-17 major earthquakes per year&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$\lambda_2 \in (24.7, 30.8)$&lt;/code&gt;: the active regime averages roughly 25-31 per year&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$\delta_1 \in (0.49, 0.83)$&lt;/code&gt;: the quiet regime accounts for roughly 49-83% of years&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;These intervals are the Bayesian answer to "how sure are we?" EM would give a single point estimate for &lt;code&gt;$\lambda_1$&lt;/code&gt; and nothing more. MCMC gives &lt;code&gt;$\lambda_1 = 15.8 \pm 0.7$&lt;/code&gt; and the full shape of the uncertainty.&lt;/p&gt;
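&lt;p&gt;An interval like these can be read straight off the posterior draws. The sketch below assumes an equal-tailed percentile interval, since the exact method behind the numbers above is not spelled out, and uses toy draws in place of a real posterior column. &lt;code&gt;credible_interval&lt;/code&gt; is a hypothetical helper.&lt;/p&gt;

```python
def credible_interval(samples, level=0.95):
    """Equal-tailed interval from posterior draws, interpolating between order statistics."""
    ordered = sorted(samples)
    n = len(ordered)

    def quantile(p):
        pos = p * (n - 1)
        lower = int(pos)
        upper = min(lower + 1, n - 1)
        frac = pos - lower
        return ordered[lower] * (1.0 - frac) + ordered[upper] * frac

    tail = (1.0 - level) / 2.0
    return quantile(tail), quantile(1.0 - tail)

# Toy draws standing in for one column of the posterior array
toy_draws = [float(i) for i in range(101)]
print(credible_interval(toy_draws))
```

&lt;p&gt;With the sampler's output, &lt;code&gt;credible_interval(posterior[:, 0])&lt;/code&gt; would give the interval for &lt;code&gt;$\lambda_1$&lt;/code&gt;. &lt;code&gt;np.percentile(posterior[:, 0], [2.5, 97.5])&lt;/code&gt; is the NumPy equivalent.&lt;/p&gt;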

&lt;p&gt;The pairs plot reveals how the parameters are correlated in the posterior:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8jq0geaphadgjeeuhp77.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8jq0geaphadgjeeuhp77.webp" alt="Posterior pairs plot showing correlations between lambda1, lambda2, and delta1" width="800" height="823"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Notice the positive correlation between &lt;code&gt;$\lambda_1$&lt;/code&gt; and &lt;code&gt;$\lambda_2$&lt;/code&gt;: when the sampler explores higher rates for the quiet regime, it tends to push the active regime higher too, to maintain the overall fit. Similarly, &lt;code&gt;$\delta_1$&lt;/code&gt; correlates positively with &lt;code&gt;$\lambda_1$&lt;/code&gt;, because increasing the quiet-regime rate means it needs to account for a larger share of the data to maintain the same overall mean.&lt;/p&gt;
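&lt;p&gt;The "same overall mean" argument can be checked with a little arithmetic. The mixture mean is &lt;code&gt;$\delta_1 \lambda_1 + (1 - \delta_1) \lambda_2$&lt;/code&gt;, so holding it fixed while nudging &lt;code&gt;$\lambda_1$&lt;/code&gt; upward forces &lt;code&gt;$\delta_1$&lt;/code&gt; upward too. The sketch plugs in the posterior means reported above; the helper names and the nudged value 16.5 are illustrative assumptions.&lt;/p&gt;

```python
def mixture_mean(lam1, lam2, d1):
    """Mean of a two-component Poisson mixture."""
    return d1 * lam1 + (1.0 - d1) * lam2

def d1_for_mean(target, lam1, lam2):
    """Mixing weight d1 that makes the mixture mean equal target."""
    return (lam2 - target) / (lam2 - lam1)

base_mean = mixture_mean(15.8, 27.1, 0.67)
print(round(base_mean, 2))                           # about 19.53
print(round(d1_for_mean(base_mean, 16.5, 27.1), 2))  # about 0.71, up from 0.67
```

&lt;p&gt;This only holds &lt;code&gt;$\lambda_2$&lt;/code&gt; fixed and matches the mean, so it is a heuristic for the direction of the correlation rather than a description of the full posterior geometry.&lt;/p&gt;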

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Why MCMC Instead of EM?
&lt;/h3&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fnrjofyuvxgw3f1j3632a.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fnrjofyuvxgw3f1j3632a.webp" alt="EM vs MCMC for mixture models: EM gives point estimates fast, MCMC gives full posteriors with uncertainty" width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We could fit this mixture with EM, which is computationally simpler and faster. So why use MCMC?&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Full posterior, not point estimates.&lt;/strong&gt; EM gives you the maximum likelihood values &lt;code&gt;$(\hat{\lambda}_1, \hat{\lambda}_2, \hat{\delta}_1)$&lt;/code&gt;. MCMC gives you the entire posterior distribution, including uncertainty quantification and credible intervals. For a dataset of only 107 points, that uncertainty matters.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;No closed-form M-step needed.&lt;/strong&gt; EM requires deriving the M-step analytically for each model. For Poisson mixtures, this is straightforward, but for more exotic likelihoods (as in our &lt;a href="https://sesen.ai/blog/one-inflated-beta-regression-pymc" rel="noopener noreferrer"&gt;one-inflated Beta regression post&lt;/a&gt;), EM can be intractable. MCMC only needs to evaluate the likelihood.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Natural uncertainty propagation.&lt;/strong&gt; If you want to predict "what's the probability that next year has more than 30 earthquakes?", you can average over the posterior samples. With EM, you'd need to bootstrap or use asymptotic approximations.&lt;/li&gt;
&lt;/ol&gt;
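&lt;p&gt;The third point can be sketched in a few lines: for each posterior draw, compute the mixture probability of exceeding the threshold, then average across draws. This uses only the standard library. &lt;code&gt;poisson_sf&lt;/code&gt;, &lt;code&gt;prob_exceeds&lt;/code&gt;, and the toy draws are assumptions for illustration, not sampler output.&lt;/p&gt;

```python
import math

def poisson_sf(k, lam):
    """P(X > k) for X ~ Poisson(lam), as 1 minus the pmf summed from 0 to k."""
    cdf = sum(math.exp(x * math.log(lam) - lam - math.lgamma(x + 1)) for x in range(k + 1))
    return max(0.0, 1.0 - cdf)

def prob_exceeds(samples, k=30):
    """Posterior predictive P(count > k), averaged over (lam1, lam2, d1) draws."""
    probs = [
        d1 * poisson_sf(k, lam1) + (1.0 - d1) * poisson_sf(k, lam2)
        for lam1, lam2, d1 in samples
    ]
    return sum(probs) / len(probs)

# Illustrative draws only; in practice pass the rows of the posterior array
toy_draws = [(15.8, 27.1, 0.67), (15.0, 26.0, 0.70), (16.4, 28.5, 0.62)]
print(prob_exceeds(toy_draws, k=30))
```

&lt;p&gt;With the sampler's output, &lt;code&gt;prob_exceeds(posterior[:, :3])&lt;/code&gt; should work the same way, since each row unpacks into the three parameters.&lt;/p&gt;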

&lt;p&gt;The trade-off is computational cost. EM converges in a few dozen iterations; our MCMC needs 1,000 iterations and still shows some roughness in the posteriors.&lt;/p&gt;

&lt;h3&gt;
  
  
  Constraint Handling by Rejection
&lt;/h3&gt;

&lt;p&gt;Our parameters have constraints: &lt;code&gt;$\lambda_1, \lambda_2 &amp;gt; 0$&lt;/code&gt; and &lt;code&gt;$0 &amp;lt; \delta_1 &amp;lt; 1$&lt;/code&gt;. Rather than using a constrained sampler or transforming variables (e.g., log-transforming the lambdas), we take the simplest approach: if a proposed point violates any constraint, we set its log-likelihood to &lt;code&gt;$-\infty$&lt;/code&gt;, which guarantees rejection.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;p_lam1&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;p_lam2&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;p_d1&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;p_d1&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;prop_ll&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;inf&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is equivalent to placing a flat prior over the valid region: &lt;code&gt;$P(\theta) \propto 1$&lt;/code&gt; for valid &lt;code&gt;$\theta$&lt;/code&gt;, &lt;code&gt;$P(\theta) = 0$&lt;/code&gt; otherwise. Strictly, this prior is improper, because the rates have no upper bound. The approach is simple and works well when the posterior is far from the boundaries (as it is here). For parameters near boundaries, a reparameterisation would be more efficient.&lt;/p&gt;
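&lt;p&gt;For reference, here is a minimal sketch of the reparameterisation route, assuming a log transform for the rates and a logit transform for &lt;code&gt;$\delta_1$&lt;/code&gt;. The function names are hypothetical. If proposals are made in the unconstrained space, the log-Jacobian has to be added to the log-likelihood so that the chain still targets the same flat-prior posterior.&lt;/p&gt;

```python
import math

def to_unconstrained(lam1, lam2, d1):
    """Map (lam1, lam2, d1) to the real line: log for rates, logit for the weight."""
    return math.log(lam1), math.log(lam2), math.log(d1 / (1.0 - d1))

def to_constrained(u1, u2, v):
    """Inverse map: exp for rates, logistic for the weight."""
    return math.exp(u1), math.exp(u2), 1.0 / (1.0 + math.exp(-v))

def log_jacobian(lam1, lam2, d1):
    """Log determinant of d(lam1, lam2, d1) / d(u1, u2, v)."""
    return math.log(lam1) + math.log(lam2) + math.log(d1) + math.log(1.0 - d1)

u = to_unconstrained(15.8, 27.1, 0.67)
print(to_constrained(*u))  # recovers the original point up to rounding
```

&lt;p&gt;Every point in the unconstrained space maps back to a valid parameter set, so no proposal is wasted on a constraint violation.&lt;/p&gt;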

&lt;h3&gt;
  
  
  Hyperparameter Sensitivity
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Parameter&lt;/th&gt;
&lt;th&gt;Value&lt;/th&gt;
&lt;th&gt;Why&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;code&gt;n_iter&lt;/code&gt; (iterations)&lt;/td&gt;
&lt;td&gt;1,000&lt;/td&gt;
&lt;td&gt;Enough for this simple 3-parameter model. More iterations would give smoother posteriors.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;burn_in&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;100&lt;/td&gt;
&lt;td&gt;The chain reaches the posterior mode within ~50 iterations; 100 gives a safety margin.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;sigma&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;diag(1, 1, 0.01)&lt;/td&gt;
&lt;td&gt;Step sizes matched to each parameter's scale. Gives ~32% acceptance rate.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;code&gt;$\lambda_1$&lt;/code&gt; init&lt;/td&gt;
&lt;td&gt;10&lt;/td&gt;
&lt;td&gt;Deliberately low to test that the chain can find the right region.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;code&gt;$\lambda_2$&lt;/code&gt; init&lt;/td&gt;
&lt;td&gt;20&lt;/td&gt;
&lt;td&gt;Close to the overall mean but below the true posterior value.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;code&gt;$\delta_1$&lt;/code&gt; init&lt;/td&gt;
&lt;td&gt;0.3&lt;/td&gt;
&lt;td&gt;Deliberately low (true value ~0.67) to test burn-in.&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;The proposal covariance is the most important tuning parameter. Making &lt;code&gt;sigma&lt;/code&gt; too large leads to high rejection rates; too small leads to slow exploration. A diagonal covariance ignores correlations between parameters. More sophisticated samplers (adaptive MH, Hamiltonian MC) tune this automatically.&lt;/p&gt;
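&lt;p&gt;The trade-off is easy to see on a toy problem. The sketch below is not the earthquake sampler: it assumes a one-dimensional standard normal target as a stand-in and measures the acceptance rate of a random-walk Metropolis chain for three step sizes. The function name and defaults are illustrative.&lt;/p&gt;

```python
import numpy as np

def acceptance_rate(step, n_iter=5000, seed=0):
    """Fraction of accepted proposals for random-walk Metropolis on a standard normal target."""
    rng = np.random.default_rng(seed)
    x, accepted = 0.0, 0
    for _ in range(n_iter):
        proposal = x + step * rng.normal()
        # Log of the target ratio; the normal's normalising constant cancels
        log_ratio = 0.5 * x**2 - 0.5 * proposal**2
        if log_ratio > np.log(rng.uniform()):
            x, accepted = proposal, accepted + 1
    return accepted / n_iter

for step in (0.1, 2.4, 50.0):
    print(f"step={step:5.1f}  acceptance={acceptance_rate(step):.2f}")
```

&lt;p&gt;Tiny steps are almost always accepted but barely move the chain, while huge steps are almost always rejected. Useful step sizes sit in between.&lt;/p&gt;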

&lt;h3&gt;
  
  
  Label Switching: The Mixture Model Trap
&lt;/h3&gt;

&lt;p&gt;There's a subtlety we glossed over: our model has a symmetry. If you swap &lt;code&gt;$\lambda_1 \leftrightarrow \lambda_2$&lt;/code&gt; and &lt;code&gt;$\delta_1 \leftrightarrow \delta_2$&lt;/code&gt; (where &lt;code&gt;$\delta_2 = 1 - \delta_1$&lt;/code&gt;), the likelihood is identical, so the posterior has two mirror-image modes. In long MCMC runs, the chain can jump between these two modes, a problem known as &lt;strong&gt;label switching&lt;/strong&gt;, producing a posterior that's a symmetric mixture of both.&lt;/p&gt;

&lt;p&gt;We avoid this here because our initial values and proposal dynamics keep &lt;code&gt;$\lambda_1 &amp;lt; \lambda_2$&lt;/code&gt; throughout the run. For more complex models or longer chains, you'd impose an ordering constraint (e.g., &lt;code&gt;$\lambda_1 &amp;lt; \lambda_2$&lt;/code&gt;) or use post-processing to relabel samples.&lt;/p&gt;
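&lt;p&gt;A minimal sketch of the post-processing option, assuming the samples are stored as an array with columns &lt;code&gt;(lam1, lam2, d1)&lt;/code&gt; (the function name and layout are hypothetical): wherever a draw has the rates in the wrong order, swap them and replace the weight by its complement.&lt;/p&gt;

```python
import numpy as np

def relabel(samples):
    """Return a copy of the samples in which the first rate is never the larger one.

    samples: array-like of shape (n, 3) with columns (lam1, lam2, d1).
    """
    src = np.asarray(samples, dtype=float)
    out = src.copy()
    swapped = src[:, 0] > src[:, 1]           # draws sitting in the mirrored mode
    out[swapped, 0] = src[swapped, 1]         # swap the two rates
    out[swapped, 1] = src[swapped, 0]
    out[swapped, 2] = 1.0 - src[swapped, 2]   # d1 becomes the other component's weight
    return out

draws = [[30.0, 15.0, 0.3], [15.0, 30.0, 0.7]]
print(relabel(draws))
```

&lt;p&gt;Both rows describe the same mixture, so after relabelling they coincide.&lt;/p&gt;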

&lt;h3&gt;
  
  
  From Scratch to PyMC
&lt;/h3&gt;

&lt;p&gt;For production work, you wouldn't write your own sampler. A probabilistic programming framework like &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;PyMC&lt;/a&gt; handles proposals, tuning, diagnostics, and convergence checking automatically. The same model in PyMC looks like:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pymc&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="n"&gt;delta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Dirichlet&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;delta&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;lambdas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;lambdas&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.05&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Mixture&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;obs&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;delta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                     &lt;span class="n"&gt;comp_dists&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Poisson&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;lambdas&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;eq&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;trace&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;By default, PyMC samples continuous parameters with the No-U-Turn Sampler (NUTS), a variant of Hamiltonian Monte Carlo that's far more efficient than our random-walk Metropolis. But the core idea is the same: draw samples from the posterior to approximate the full distribution.&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Metropolis et al. (1953): The Original MCMC Paper
&lt;/h3&gt;

&lt;p&gt;The Metropolis-Hastings algorithm has its roots in the physics of the atomic bomb. &lt;a href="https://doi.org/10.1063/1.1699114" rel="noopener noreferrer"&gt;Metropolis, Rosenbluth, Rosenbluth, Teller, and Teller (1953)&lt;/a&gt; developed the algorithm at Los Alamos to simulate the thermodynamic equilibrium of interacting molecules. Their key insight was that you don't need to compute the normalisation constant of a distribution to sample from it; you only need the ratio of probabilities at two points.&lt;/p&gt;

&lt;p&gt;W.K. Hastings generalised the algorithm in &lt;a href="https://doi.org/10.1093/biomet/57.1.97" rel="noopener noreferrer"&gt;1970&lt;/a&gt; to allow asymmetric proposal distributions (Metropolis' version required symmetric proposals). Our implementation uses a symmetric multivariate normal proposal, so we're technically using the original Metropolis algorithm, a special case of Metropolis-Hastings.&lt;/p&gt;

&lt;p&gt;The acceptance rule we used:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha%2520%253D%2520%255Cmin%255Cleft%281%252C%2520%255Cfrac%257BL%28%255Ctheta%27%29%257D%257BL%28%255Ctheta%29%257D%255Cright%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha%2520%253D%2520%255Cmin%255Cleft%281%252C%2520%255Cfrac%257BL%28%255Ctheta%27%29%257D%257BL%28%255Ctheta%29%257D%255Cright%29" alt="equation" width="207" height="60"&gt;&lt;/a&gt;&lt;/p&gt;

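&lt;p&gt;is the Metropolis acceptance probability. Because the prior is flat over the valid region, the posterior ratio reduces to the likelihood ratio shown here. The beauty of it: the normalisation constant cancels out in the ratio, so we never need to compute &lt;code&gt;$\int L(\theta) d\theta$&lt;/code&gt;, which is generally intractable.&lt;/p&gt;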
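&lt;p&gt;In code, this rule is usually evaluated on the log scale so that very small likelihoods do not underflow. The sketch below assumes you already have the two log-likelihood values; the function name is hypothetical.&lt;/p&gt;

```python
import numpy as np

def metropolis_accept(log_l_current, log_l_proposed, rng):
    """Accept with probability min(1, L(proposed) / L(current)), computed in log space."""
    log_alpha = min(0.0, log_l_proposed - log_l_current)
    # Comparing log(u) with log(alpha) is the same as comparing u with alpha
    return bool(log_alpha > np.log(rng.uniform()))

rng = np.random.default_rng(0)
print(metropolis_accept(-10.0, -5.0, rng))     # uphill move: always accepted
print(metropolis_accept(-10.0, -np.inf, rng))  # invalid proposal: always rejected
```

&lt;p&gt;A proposal with log-likelihood &lt;code&gt;-np.inf&lt;/code&gt;, as produced by the boundary check earlier, is rejected automatically.&lt;/p&gt;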

&lt;blockquote&gt;
&lt;p&gt;"The purpose of this paper is to describe a general method, suitable for fast electronic computing machines, of calculating the properties of any substance which may be considered as composed of interacting individual molecules."&lt;br&gt;
-- Metropolis et al. (1953)&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;They were simulating liquid states of hard spheres. Seven decades later, we're using the same algorithm to infer earthquake regimes.&lt;/p&gt;

&lt;h3&gt;
  
  
  Mixture Models and Data Augmentation
&lt;/h3&gt;

&lt;p&gt;The idea of fitting mixture models with MCMC goes back to &lt;a href="https://doi.org/10.1080/01621459.1987.10478458" rel="noopener noreferrer"&gt;Tanner and Wong (1987)&lt;/a&gt;, "The Calculation of Posterior Distributions by Data Augmentation." They showed that by introducing latent variables (which component generated each observation), you can construct a Gibbs sampler (an MCMC algorithm that samples each variable in turn, conditioning on the rest) that alternates between sampling component assignments and sampling parameters.&lt;/p&gt;

&lt;p&gt;Our approach is different: we marginalise out the component assignments (sum over all possible assignments so they no longer appear in the expression) and sample only the parameters &lt;code&gt;$(\lambda_1, \lambda_2, \delta_1)$&lt;/code&gt; directly. This is simpler to implement but means we can't recover which regime each year belongs to (without a separate step). The marginalised likelihood is what makes our log-likelihood function work:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DL%28%255Clambda_1%252C%2520%255Clambda_2%252C%2520%255Cdelta_1%29%2520%253D%2520%255Cprod_%257Bi%253D1%257D%255E%257Bn%257D%2520%255Cleft%255B%2520%255Cdelta_1%2520%255Ccdot%2520%255Ctext%257BPois%257D%28x_i%2520%255Cmid%2520%255Clambda_1%29%2520%252B%2520%281%2520-%2520%255Cdelta_1%29%2520%255Ccdot%2520%255Ctext%257BPois%257D%28x_i%2520%255Cmid%2520%255Clambda_2%29%2520%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DL%28%255Clambda_1%252C%2520%255Clambda_2%252C%2520%255Cdelta_1%29%2520%253D%2520%255Cprod_%257Bi%253D1%257D%255E%257Bn%257D%2520%255Cleft%255B%2520%255Cdelta_1%2520%255Ccdot%2520%255Ctext%257BPois%257D%28x_i%2520%255Cmid%2520%255Clambda_1%29%2520%252B%2520%281%2520-%2520%255Cdelta_1%29%2520%255Ccdot%2520%255Ctext%257BPois%257D%28x_i%2520%255Cmid%2520%255Clambda_2%29%2520%255Cright%255D" alt="equation" width="631" height="68"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  The Earthquake Data
&lt;/h3&gt;

&lt;p&gt;Our dataset comes from &lt;a href="https://www.taylorfrancis.com/books/mono/10.1201/b20790/hidden-markov-models-time-series-walter-zucchini-iain-macdonald-roland-langrock" rel="noopener noreferrer"&gt;Zucchini, MacDonald and Langrock (2016)&lt;/a&gt;, &lt;em&gt;Hidden Markov Models for Time Series: An Introduction Using R&lt;/em&gt;, page 10. They use this same data to motivate Hidden Markov Models, noting that a simple mixture model ignores the temporal structure (whether the regime persists from year to year). An HMM would add transition probabilities between regimes. We explored HMMs in our &lt;a href="https://sesen.ai/blog/hidden-markov-models-when-clusters-have-memory" rel="noopener noreferrer"&gt;Hidden Markov Models post&lt;/a&gt;.&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;The original MCMC paper:&lt;/strong&gt; &lt;a href="https://doi.org/10.1063/1.1699114" rel="noopener noreferrer"&gt;Metropolis et al. (1953)&lt;/a&gt;, "Equation of State Calculations by Fast Computing Machines"&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The Hastings generalisation:&lt;/strong&gt; &lt;a href="https://doi.org/10.1093/biomet/57.1.97" rel="noopener noreferrer"&gt;Hastings (1970)&lt;/a&gt;, "Monte Carlo Sampling Methods Using Markov Chains and Their Applications"&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Data augmentation for mixtures:&lt;/strong&gt; &lt;a href="https://doi.org/10.1080/01621459.1987.10478458" rel="noopener noreferrer"&gt;Tanner &amp;amp; Wong (1987)&lt;/a&gt;, "The Calculation of Posterior Distributions by Data Augmentation"&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Comprehensive MCMC reference:&lt;/strong&gt; Robert &amp;amp; Casella (2004), &lt;em&gt;Monte Carlo Statistical Methods&lt;/em&gt;, Springer&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The earthquake dataset source:&lt;/strong&gt; &lt;a href="https://www.taylorfrancis.com/books/mono/10.1201/b20790/hidden-markov-models-time-series-walter-zucchini-iain-macdonald-roland-langrock" rel="noopener noreferrer"&gt;Zucchini, MacDonald &amp;amp; Langrock (2016)&lt;/a&gt;, &lt;em&gt;Hidden Markov Models for Time Series&lt;/em&gt;, Chapter 1&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Next step:&lt;/strong&gt; Our &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;Hierarchical Bayesian Regression with PyMC post&lt;/a&gt; shows how a modern PPL automates the sampling&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Interactive Tools
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/markov-chain-simulator" rel="noopener noreferrer"&gt;Markov Chain Simulator&lt;/a&gt; — Explore Markov chain dynamics and transition matrices interactively&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/distribution-explorer" rel="noopener noreferrer"&gt;Distribution Explorer&lt;/a&gt; — Visualise the Poisson and other distributions used in this mixture model&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Related Posts
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Island Hopping: Understanding Metropolis-Hastings&lt;/a&gt; (MH fundamentals on a discrete example)&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/gaussian-mixture-models-em-in-practice" rel="noopener noreferrer"&gt;Gaussian Mixture Models: EM in Practice&lt;/a&gt; (the EM alternative to MCMC for mixture models)&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt; (why go Bayesian in the first place)&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/hidden-markov-models-when-clusters-have-memory" rel="noopener noreferrer"&gt;Hidden Markov Models: When Clusters Have Memory&lt;/a&gt; (adding temporal structure to mixture models)&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is the difference between MCMC and EM for fitting mixture models?
&lt;/h3&gt;

&lt;p&gt;EM (Expectation-Maximisation) gives you point estimates of the mixture parameters by iteratively maximising the likelihood. MCMC gives you the full posterior distribution, including uncertainty quantification through credible intervals. For small datasets like the 107-year earthquake record, knowing how uncertain your estimates are is often more valuable than having a single best guess.&lt;/p&gt;

&lt;h3&gt;
  
  
  How do I choose the proposal covariance in Metropolis-Hastings?
&lt;/h3&gt;

&lt;p&gt;The proposal covariance controls the step size of the random walk. Each diagonal entry should roughly match the scale of the corresponding parameter. A good rule of thumb is to target an acceptance rate between 20% and 50% for multivariate problems. If the acceptance rate is too low, shrink the proposal so the chain takes smaller steps; if it is too high, widen the proposal.&lt;/p&gt;

&lt;h3&gt;
  
  
  What does the burn-in period do and how long should it be?
&lt;/h3&gt;

&lt;p&gt;Burn-in discards the initial samples where the chain is migrating from its starting point to the high-probability region of the posterior. Its length depends on how far the initial values are from the posterior mode and how quickly the chain moves. Examining trace plots is the most reliable way to judge whether burn-in is sufficient.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can this approach handle more than two mixture components?
&lt;/h3&gt;

&lt;p&gt;Yes. You would add extra rate parameters and mixing weights, increasing the dimensionality of the parameter space. The Metropolis-Hastings algorithm works the same way, though you may need more iterations and careful tuning of the proposal distribution as dimensionality grows.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why use a Poisson mixture instead of a single Poisson distribution?
&lt;/h3&gt;

&lt;p&gt;A single Poisson distribution has its variance equal to its mean, so it cannot capture the overdispersion seen in the earthquake data. The two-component mixture allows for two distinct rates, naturally producing a wider spread and the bimodal shape visible in the histogram.&lt;/p&gt;
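&lt;p&gt;The extra spread can be written down directly. By the law of total variance, a two-component Poisson mixture has variance equal to its mean plus a term that grows with the gap between the two rates. The sketch below uses made-up rates purely for illustration.&lt;/p&gt;

```python
def mixture_mean_var(lam1, lam2, d1):
    """Mean and variance of a two-component Poisson mixture with weight d1 on lam1."""
    mean = d1 * lam1 + (1.0 - d1) * lam2
    # Within-component variance equals the mean; the second term is the between-component spread
    var = mean + d1 * (1.0 - d1) * (lam1 - lam2) ** 2
    return mean, var

mean, var = mixture_mean_var(10.0, 30.0, 0.5)
print(mean, var)  # 20.0 120.0
```

&lt;p&gt;A single Poisson with mean 20 would have variance 20, so the mixture is overdispersed whenever the two rates differ.&lt;/p&gt;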

&lt;h3&gt;
  
  
  What is label switching and how does it affect the results?
&lt;/h3&gt;

&lt;p&gt;Label switching occurs because swapping the two components (exchanging their rates and mixing weights) produces an identical likelihood. In long MCMC runs, the chain can jump between these symmetric modes, blurring the posterior. The simplest fix is to impose an ordering constraint such as requiring the first rate to be smaller than the second.&lt;/p&gt;

</description>
      <category>bayesian</category>
      <category>sampling</category>
      <category>probabilistic</category>
    </item>
    <item>
      <title>Q-Learning for Games: Teaching an Agent Tic-Tac-Toe Through Self-Play</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Mon, 11 May 2026 10:07:25 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/q-learning-for-games-teaching-an-agent-tic-tac-toe-through-self-play-3n6d</link>
      <guid>https://dev.to/berkan_sesen/q-learning-for-games-teaching-an-agent-tic-tac-toe-through-self-play-3n6d</guid>
      <description>&lt;p&gt;Tic-tac-toe is a solved game. Any competent adult can force a draw every time. But can an agent figure that out with zero human knowledge? Give two agents a blank board, a few simple rules about wins and losses, and nothing else. No opening theory, no strategy guides, no human games to study. After 100,000 games of fumbling against each other, they discover forks, blocking, and centre-first openings entirely on their own.&lt;/p&gt;

&lt;p&gt;This is Q-learning applied to games. In our &lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;previous Q-learning post&lt;/a&gt;, the agent navigated a frozen lake alone, learning from its own mistakes. Here, we add an opponent. The agent can't just learn the environment; it must learn to &lt;em&gt;outsmart&lt;/em&gt; another learner who's improving at the same time.&lt;/p&gt;

&lt;p&gt;By the end of this post, you'll build two Q-learning agents that teach each other tic-tac-toe through self-play, and you'll understand why this simple setup discovers remarkably strong strategy.&lt;/p&gt;

&lt;h2&gt;
  
  
  The Problem: Tic-Tac-Toe as an RL Environment
&lt;/h2&gt;

&lt;p&gt;Tic-tac-toe is the simplest non-trivial two-player game. The board has 9 cells, two players alternate placing X and O, and the first to complete a row, column, or diagonal wins. If all cells are filled with no winner, it's a draw.&lt;/p&gt;

&lt;p&gt;As an RL problem:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;State&lt;/strong&gt;: the current board (which cells have X, O, or are empty)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Actions&lt;/strong&gt;: place your marker on any empty cell&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Reward&lt;/strong&gt;: +1 for winning, -1 for losing, 0 for a draw or an ongoing game&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Transition&lt;/strong&gt;: deterministic (unlike the slippery FrozenLake), but the opponent's move is stochastic from your perspective&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The state space is manageable: there are at most &lt;code&gt;$3^9 = 19{,}683$&lt;/code&gt; possible board configurations (fewer in practice, since many are unreachable). This makes tabular Q-learning a perfect fit, with no need for &lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;neural network function approximation&lt;/a&gt;.&lt;/p&gt;
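&lt;p&gt;To check how many of those configurations can actually occur, one can enumerate them. The sketch below is a standalone helper, not part of the agent: it assumes X (encoded as 1) always moves first, O is -1, and play stops as soon as someone wins.&lt;/p&gt;

```python
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),   # rows
         (0, 3, 6), (1, 4, 7), (2, 5, 8),   # columns
         (0, 4, 8), (2, 4, 6)]              # diagonals

def winner(board):
    """Return 1 or -1 if that player has three in a row, else 0."""
    for a, b, c in LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0

def count_reachable_states():
    """Count boards reachable from the empty board through legal play."""
    start = (0,) * 9
    seen = {start}
    frontier = [(start, 1)]                 # (board, player to move)
    while frontier:
        board, player = frontier.pop()
        if winner(board) != 0:
            continue                        # game over: no further moves
        for i in range(9):
            if board[i] == 0:
                nxt = board[:i] + (player,) + board[i + 1:]
                if nxt not in seen:
                    seen.add(nxt)
                    frontier.append((nxt, -player))
    return len(seen)

print(count_reachable_states(), "of", 3 ** 9)
```

&lt;p&gt;Either way, the table stays small enough to hold in a plain dictionary.&lt;/p&gt;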

&lt;h2&gt;
  
  
  Quick Win: Self-Play in Action
&lt;/h2&gt;

&lt;p&gt;Let's see two Q-learning agents teach each other from scratch. Click the badge to run this yourself:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/tic_tac_toe_q_learning.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Watch how the agents' play evolves from random moves (early training) to strategic play (late training):&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fj7g70ecuu7bdsrs6nxzn.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fj7g70ecuu7bdsrs6nxzn.gif" alt="Skill progression over training, from random moves to strategic blocking and winning" width="600" height="700"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Here's the complete implementation. We need three pieces: an environment, an agent, and a self-play training loop.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;TicTacToe&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Tic-tac-toe environment. Board is a flat array of 9 cells.
    Values: 0=empty, 1=X, -1=O.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;9&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;9&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;copy&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;available_actions&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;marker&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;marker&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_is_winner&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
            &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;copy&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;win&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
        &lt;span class="k"&gt;elif&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;available_actions&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;copy&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;draw&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;copy&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;ongoing&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_is_winner&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;diag&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;diag&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fliplr&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;)).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The agent is a standard Q-learner with one key adaptation: Q-values for occupied cells are set to &lt;code&gt;NaN&lt;/code&gt; so the agent never tries to play in a taken position.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;QLearningAgent&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;marker&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                 &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;final_epsilon&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.05&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;marker&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;marker&lt;/span&gt;       &lt;span class="c1"&gt;# 1 for X, -1 for O
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;epsilon&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;final_epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;final_epsilon&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_table&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{}&lt;/span&gt;          &lt;span class="c1"&gt;# {tuple(state): np.array(9)}
&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_get_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_table&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;full&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;9&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nan&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.0&lt;/span&gt;    &lt;span class="c1"&gt;# only empty cells get Q-values
&lt;/span&gt;            &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_table&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;q_table&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;pick_action&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;available&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rand&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;choice&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;available&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_get_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;available_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;available&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;max_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;available_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;best&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;available_q&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;max_q&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;choice&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;best&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_get_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;target&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;
        &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;next_q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;_get_q&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;target&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;nanmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;next_q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now the self-play training loop. Both agents learn simultaneously, with the loser receiving a -1 reward when the other wins:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;TicTacToe&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;agent_x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;QLearningAgent&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;marker&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;agent_o&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;QLearningAgent&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;marker&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;eps_decay&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;2.5e-5&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ep&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100_000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;agents&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;agent_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;agent_o&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;random&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;agents&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;agent_o&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;agent_x&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;  &lt;span class="c1"&gt;# randomise who goes first
&lt;/span&gt;    &lt;span class="n"&gt;turn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="n"&gt;history&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
    &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;

    &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;agent&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;agents&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;turn&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;copy&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;pick_action&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;marker&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="c1"&gt;# winner learns from the final move
&lt;/span&gt;            &lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="c1"&gt;# loser learns too: propagate -reward to their last move
&lt;/span&gt;            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;win&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;other&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;agents&lt;/span&gt;&lt;span class="p"&gt;[(&lt;/span&gt;&lt;span class="n"&gt;turn&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
                &lt;span class="n"&gt;prev&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
                &lt;span class="n"&gt;other&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prev&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;prev&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;agent&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;next_state&lt;/span&gt;
        &lt;span class="n"&gt;turn&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

    &lt;span class="c1"&gt;# decay epsilon for both agents
&lt;/span&gt;    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;agent_x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;agent_o&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;final_epsilon&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;-=&lt;/span&gt; &lt;span class="n"&gt;eps_decay&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;After training, both agents win around 85% of games against a random opponent (85% for X, 84% for O):&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fviarlzefel54rvsqjqbf.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fviarlzefel54rvsqjqbf.webp" alt="Both agents win around 85% against a random opponent after 100k episodes of self-play" width="800" height="339"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;You just trained two agents to play tic-tac-toe without teaching them a single strategy. Let's understand how.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Board as State, Cells as Actions
&lt;/h3&gt;

&lt;p&gt;The environment represents the board as a flat array of 9 integers: &lt;code&gt;1&lt;/code&gt; for X, &lt;code&gt;-1&lt;/code&gt; for O, &lt;code&gt;0&lt;/code&gt; for empty. This encoding is compact and makes win detection elegant: a row, column, or diagonal filled by one player sums to +3 (X wins) or -3 (O wins).&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Check rows, columns, diagonals
&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;    &lt;span class="c1"&gt;# row i
&lt;/span&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="c1"&gt;# column i
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
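&lt;p&gt;For readers who want to try the sum trick in isolation, here is a minimal, self-contained sketch. The function name &lt;code&gt;has_winner&lt;/code&gt; and the sample boards are illustrative assumptions, not part of the environment class above:&lt;/p&gt;

```python
import numpy as np

def has_winner(state):
    """Return True if one player fills a row, column or diagonal."""
    b = np.asarray(state).reshape(3, 3)
    line_sums = np.concatenate([
        b.sum(axis=1),                              # three row sums
        b.sum(axis=0),                              # three column sums
        [np.trace(b), np.trace(np.fliplr(b))],      # two diagonal sums
    ])
    # A line owned entirely by X sums to +3, by O to -3
    return bool(np.any(np.abs(line_sums) == 3))

# Hypothetical boards: X=1, O=-1, empty=0
x_top_row = np.array([1, 1, 1, -1, -1, 0, 0, 0, 0])
o_diagonal = np.array([-1, 1, 1, 0, -1, 0, 1, 0, -1])
no_winner = np.array([1, -1, 1, 0, 0, 0, 0, 0, 0])

print(has_winner(x_top_row), has_winner(o_diagonal), has_winner(no_winner))
```

&lt;p&gt;Summing all eight lines at once trades the early exit of the original checks for a single vectorised comparison; on a 3x3 board either approach is fine.&lt;/p&gt;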



&lt;p&gt;The action space is the set of empty cells. Occupied positions hold &lt;code&gt;NaN&lt;/code&gt; in the Q-table, so illegal moves never carry a value: &lt;code&gt;pick_action&lt;/code&gt; only considers empty cells, and &lt;code&gt;np.nanmax&lt;/code&gt; ignores &lt;code&gt;NaN&lt;/code&gt; entries when computing the update target:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;full&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;9&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nan&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.0&lt;/span&gt;  &lt;span class="c1"&gt;# only legal moves get Q-values
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
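&lt;p&gt;A quick way to see the masking at work is to compare &lt;code&gt;np.max&lt;/code&gt; with its NaN-aware counterparts on a hand-built Q-vector. The board and the value 0.5 below are invented for illustration:&lt;/p&gt;

```python
import numpy as np

# Hypothetical mid-game board: X=1, O=-1, empty=0
state = np.array([1, -1, 0, 0, 1, 0, -1, 0, 0])

q = np.full(9, np.nan)
q[state == 0] = 0.0      # legal moves start at zero, occupied cells stay NaN
q[8] = 0.5               # pretend cell 8 has been learned to be good

plain_max = np.max(q)                   # NaN: a plain max is poisoned by the mask
best_value = np.nanmax(q)               # 0.5: NaN entries are skipped
best_action = int(np.nanargmax(q))      # index of the best legal move
legal = np.flatnonzero(~np.isnan(q))    # cells that still hold a Q-value

print(plain_max, best_value, best_action, legal)
```

&lt;p&gt;The same mask therefore serves two purposes: it marks which cells are playable and it keeps occupied cells out of any maximum taken over the Q-vector.&lt;/p&gt;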



&lt;h3&gt;
  
  
  Self-Play: The Opponent is the Curriculum
&lt;/h3&gt;

&lt;p&gt;The key insight of self-play is that both agents improve together. In early training, epsilon (the probability of choosing a random action instead of the greedy one) starts at 1.0, so both agents play essentially at random and wins and losses are noisy. As epsilon decays linearly towards 0.05, they increasingly exploit what they've learned, and the opponent becomes a tougher challenge.&lt;/p&gt;

&lt;p&gt;This creates an &lt;strong&gt;arms race&lt;/strong&gt;. Watch the training curve:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fox7kvhodn8i88mwqhjg7.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fox7kvhodn8i88mwqhjg7.webp" alt="Self-play training dynamics: draw rate rises from 10% to over 40% as both agents improve" width="800" height="445"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Three things happen as training progresses:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Draw rate rises&lt;/strong&gt; from ~10% to ~42%. Both agents get better at defending, so fewer games end in a clear win.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Win rates equalise&lt;/strong&gt;. X starts with a slight advantage (going first), but by the end, both hover around 30%.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The transition is sharp&lt;/strong&gt;. With &lt;code&gt;eps_decay = 2.5e-5&lt;/code&gt;, epsilon has fallen to roughly 0.25 by around episode 30,000, so agents exploit their Q-values far more often than they explore. The draw rate shoots up.&lt;/li&gt;
&lt;/ol&gt;
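&lt;p&gt;The timing of that transition follows from the decay schedule. The sketch below is a closed-form approximation of the linear decay in the training loop, using the same start value, floor and step size; the helper name &lt;code&gt;epsilon_at&lt;/code&gt; is an assumption introduced here:&lt;/p&gt;

```python
def epsilon_at(episode, start=1.0, floor=0.05, decay=2.5e-5):
    """Closed-form epsilon after `episode` episodes of linear decay."""
    return max(floor, start - decay * episode)

# Episodes needed to travel from the start value down to the floor
episodes_to_floor = round((1.0 - 0.05) / 2.5e-5)

# Epsilon at a few checkpoints, rounded for display
checkpoints = (0, 20_000, 30_000, 38_000, 100_000)
schedule = {ep: round(epsilon_at(ep), 3) for ep in checkpoints}

print(episodes_to_floor)
print(schedule)
```

&lt;p&gt;Under these settings epsilon crosses 0.5 near episode 20,000, sits near 0.25 at episode 30,000, and reaches its floor of 0.05 near episode 38,000; the remaining episodes run at the floor. The training loop decrements epsilon step by step, so its values can differ from this closed form by floating-point rounding.&lt;/p&gt;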

&lt;h3&gt;
  
  
  Reward Propagation in Adversarial Games
&lt;/h3&gt;

&lt;p&gt;In single-agent Q-learning (like &lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;FrozenLake&lt;/a&gt;), the agent updates after every step. In a two-player game, we need an extra mechanism: when one agent wins, the &lt;strong&gt;loser&lt;/strong&gt; must also learn from its last move.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;info&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;win&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;other&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;agents&lt;/span&gt;&lt;span class="p"&gt;[(&lt;/span&gt;&lt;span class="n"&gt;turn&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;prev&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;history&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;other&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prev&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;prev&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The winner gets reward +1. The loser's last move gets -1. This is how the agent learns defensive play: "my last move let my opponent win on the very next turn, so that was a bad move."&lt;/p&gt;
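&lt;p&gt;To make the loser's update concrete, here is a small worked sketch with &lt;code&gt;lr = 1.0&lt;/code&gt;. The board position, the chosen cell and the variable names are made up for illustration; only the update rule mirrors the &lt;code&gt;update&lt;/code&gt; method:&lt;/p&gt;

```python
import numpy as np

lr = 1.0

# Loser's Q-values for the state in which it made its last move.
# Cells 0 and 3 are occupied in this made-up position, so they stay NaN.
q_loser = np.array([np.nan, 0.0, 0.0, np.nan, 0.0, 0.0, 0.0, 0.0, 0.0])

last_action = 4          # hypothetical move that failed to block
winner_reward = 1.0      # reward the winner received for the final move

# Terminal update: the target is just the negated reward, with no bootstrap term
target = -winner_reward
q_loser[last_action] += lr * (target - q_loser[last_action])

print(q_loser)
```

&lt;p&gt;After this single update the failed move carries a value of -1 while the untried cells remain at 0, so a greedy agent visiting the same state again will prefer any other legal cell.&lt;/p&gt;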

&lt;h3&gt;
  
  
  Reading the Q-Values
&lt;/h3&gt;

&lt;p&gt;The Q-table is where the agent's strategy lives. Each entry says: "from this board state, how good is it to play in this cell?" Let's look at three critical situations the agent learned to handle:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fro60e1ws5rr2zzsav7i5.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fro60e1ws5rr2zzsav7i5.webp" alt="Three board states showing the agent's learned Q-values: forking, blocking, and winning" width="800" height="296"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Left panel (Set Up a Fork):&lt;/strong&gt; X has the centre and top-left corner. The agent assigns Q = +0.85 to the bottom-right corner (position 8), which creates a &lt;strong&gt;fork&lt;/strong&gt;: two ways to win that the opponent can't both block. Every other empty cell gets Q = 0.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Centre panel (Block or Lose):&lt;/strong&gt; O has positions 0 and 3, threatening to complete the left column. The Q-values here are all negative except position 6 (Q = 0.00), the blocking move. The agent learned that &lt;em&gt;not&lt;/em&gt; blocking leads to certain defeat. Notice the agent didn't just learn that position 6 is good; it learned that every other option is &lt;em&gt;bad&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Right panel (Take the Win):&lt;/strong&gt; X has positions 0 and 1, one move away from completing the top row. Position 2 gets Q = +0.81. The agent learned to finish the game when the opportunity is there, rather than play elsewhere.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fb1txyp5983cjyg02965d.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fb1txyp5983cjyg02965d.webp" alt="The self-play training loop: two agents improve simultaneously by playing each other" width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  Q-Learning in Games vs Single-Agent Environments
&lt;/h3&gt;

&lt;p&gt;In a single-agent setting like &lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;FrozenLake&lt;/a&gt; or &lt;a href="https://sesen.ai/blog/value-iteration-q-learning-dynamic-programming-meets-rl" rel="noopener noreferrer"&gt;Value Iteration on a grid world&lt;/a&gt;, the environment is stationary. The transition probabilities don't change. In a game with self-play, the "environment" includes the opponent, and the opponent is changing constantly.&lt;/p&gt;

&lt;p&gt;This means Q-learning in games violates a core assumption: stationarity. The Markov property still holds (the board state contains all relevant information), but the transition dynamics shift as the opponent improves. In practice, this works because both agents improve gradually, and the learning rate is high enough to track the changing opponent.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Learning Rate = 1 Choice
&lt;/h3&gt;

&lt;p&gt;You might have noticed &lt;code&gt;lr=1.0&lt;/code&gt;, which seems aggressive. With &lt;code&gt;$\alpha = 1$&lt;/code&gt;, each Q-update completely replaces the old value:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DQ%28s%252C%2520a%29%2520%255Cleftarrow%2520r%2520%252B%2520%255Cgamma%2520%255Cmax_%257Ba%27%257D%2520Q%28s%27%252C%2520a%27%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DQ%28s%252C%2520a%29%2520%255Cleftarrow%2520r%2520%252B%2520%255Cgamma%2520%255Cmax_%257Ba%27%257D%2520Q%28s%27%252C%2520a%27%29" alt="equation" width="305" height="36"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This works for tic-tac-toe because the game is &lt;strong&gt;deterministic&lt;/strong&gt;: from a given board state, taking a specific action always produces the same next state (your move is deterministic; only the opponent's response varies). With &lt;code&gt;$\alpha = 1$&lt;/code&gt;, the agent always uses the most recent outcome, which adapts quickly to the opponent's evolving strategy.&lt;/p&gt;

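&lt;p&gt;For stochastic environments, &lt;code&gt;$\alpha = 1$&lt;/code&gt; would be a poor choice: every update overwrites the estimate with a single noisy sample, so the Q-value never averages over outcomes. With deterministic transitions in a short game, that overwrite costs little and keeps the table in step with the current opponent.&lt;/p&gt;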
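
&lt;p&gt;To make the overwrite concrete, here is a minimal sketch of a single tabular update. The function name &lt;code&gt;q_update&lt;/code&gt; and the numbers are illustrative assumptions, not taken from the source implementation; only the update rule itself comes from the text above.&lt;/p&gt;

```python
def q_update(q_old, reward, max_next_q, alpha=1.0, gamma=0.95):
    """One tabular Q-learning step for a single (state, action) entry.

    Hypothetical helper for illustration; not the author's code.
    """
    td_target = reward + gamma * max_next_q
    return q_old + alpha * (td_target - q_old)


# With alpha = 1 the old estimate drops out and the entry becomes the TD target.
overwritten = q_update(q_old=0.3, reward=0.0, max_next_q=1.0, alpha=1.0)

# With alpha = 0.5 the entry moves only halfway from the old value to the target.
blended = q_update(q_old=0.3, reward=0.0, max_next_q=1.0, alpha=0.5)

print(overwritten, blended)
```

&lt;p&gt;The first call lands on the target regardless of &lt;code&gt;q_old&lt;/code&gt;, which is the "forgetting" behaviour; the second keeps half of the previous estimate.&lt;/p&gt;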

&lt;h3&gt;
  
  
  The Self-Play Arms Race
&lt;/h3&gt;

&lt;p&gt;Self-play training has a characteristic signature: the draw rate is a proxy for skill. When two beginners play, most games end in wins (because both make exploitable mistakes). When two experts play, most games end in draws (because neither makes a mistake worth exploiting).&lt;/p&gt;

&lt;p&gt;Tic-tac-toe with perfect play from both sides is provably a draw. Our agents' ~42% draw rate suggests they're strong but not perfect: they're still occasionally making mistakes that the opponent can exploit.&lt;/p&gt;

&lt;h3&gt;
  
  
  Hyperparameter Sensitivity
&lt;/h3&gt;

&lt;p&gt;The original code uses these values, all from the &lt;a href="https://github.com/zhubarb/sesen_ai_ml_tutorials" rel="noopener noreferrer"&gt;source implementation&lt;/a&gt;:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Parameter&lt;/th&gt;
&lt;th&gt;Value&lt;/th&gt;
&lt;th&gt;Why&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;gamma&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;0.95&lt;/td&gt;
&lt;td&gt;Games are short (5-9 moves), so moderate discounting works. Higher values (0.99) also work.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;lr&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;1.0&lt;/td&gt;
&lt;td&gt;Deterministic transitions; always use the latest outcome.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;epsilon&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;1.0 to 0.05&lt;/td&gt;
&lt;td&gt;Start fully random, end mostly greedy.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;eps_decay&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;2.5e-5&lt;/td&gt;
&lt;td&gt;Linear decay over ~38,000 episodes to reach &lt;code&gt;final_epsilon&lt;/code&gt;.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;episodes&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;100,000&lt;/td&gt;
&lt;td&gt;Enough for the Q-table to converge on the ~6,600 reachable states.&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;
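
&lt;p&gt;The epsilon row and the &lt;code&gt;eps_decay&lt;/code&gt; row can be cross-checked with a few lines. This sketch assumes the decay is subtracted once per episode and floored at the final value; the helper name &lt;code&gt;linear_epsilon&lt;/code&gt; is hypothetical and the source implementation may structure this differently.&lt;/p&gt;

```python
def linear_epsilon(episode, start=1.0, final=0.05, decay=2.5e-5):
    """Epsilon after `episode` episodes of linear decay, floored at `final`.

    Assumes one decay step per episode (an assumption for this sketch).
    """
    return max(final, start - decay * episode)


# Episodes needed to travel from the start value to the floor.
episodes_to_floor = round((1.0 - 0.05) / 2.5e-5)

print(episodes_to_floor)        # 38000
print(linear_epsilon(0))        # 1.0
print(linear_epsilon(100_000))  # 0.05
```

&lt;p&gt;Under that assumption, the remaining ~62,000 of the 100,000 episodes run at the floor value of 0.05.&lt;/p&gt;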

&lt;p&gt;The Q-table ends up with roughly 6,600 entries (out of the theoretical 19,683 board configurations). Many configurations are unreachable in valid play (e.g., a board where X has played 5 times but O has played once).&lt;/p&gt;

&lt;h3&gt;
  
  
  When NOT to Use Tabular Q-Learning for Games
&lt;/h3&gt;

&lt;p&gt;Tabular Q-learning works beautifully for tic-tac-toe because the state space is tiny. It fails for:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Chess&lt;/strong&gt; (&lt;code&gt;$\sim 10^{44}$&lt;/code&gt; legal positions): the Q-table would be impossibly large&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Go&lt;/strong&gt; (&lt;code&gt;$\sim 10^{170}$&lt;/code&gt;): even worse&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Games with continuous state spaces&lt;/strong&gt;: no table can hold them&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;For these, you need function approximation: &lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;deep Q-networks&lt;/a&gt; replace the table with a neural network, or &lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;policy gradient methods&lt;/a&gt; learn a policy directly. The ideas from this post (self-play, reward propagation, exploration) carry forward directly.&lt;/p&gt;

&lt;h3&gt;
  
  
  Comparison: Self-Play vs Teacher
&lt;/h3&gt;

&lt;p&gt;Our implementation uses self-play: both agents learn simultaneously. An alternative approach (also in the original code) trains against a &lt;strong&gt;teacher&lt;/strong&gt;, a heuristic opponent that plays well but not perfectly. Self-play has the advantage of needing no hand-built curriculum: you don't need to design a teacher, and the difficulty automatically scales with the learner's ability. The downside is that training can be unstable early on, because the quality of the training signal depends on having a reasonable opponent, and early in self-play neither agent is one.&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Roots: Watkins and Temporal Difference Learning
&lt;/h3&gt;

&lt;p&gt;Q-learning was introduced by &lt;a href="https://www.cs.rhul.ac.uk/~chrisw/new_thesis.pdf" rel="noopener noreferrer"&gt;Chris Watkins in his 1989 PhD thesis&lt;/a&gt;, "Learning from Delayed Rewards." The core idea is that an agent can learn the value of actions without knowing the environment's dynamics, purely from the reward signal and the temporal difference between consecutive estimates.&lt;/p&gt;

&lt;p&gt;The update rule we used is exactly Watkins' formulation:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DQ%28s_t%252C%2520a_t%29%2520%255Cleftarrow%2520Q%28s_t%252C%2520a_t%29%2520%252B%2520%255Calpha%2520%255Cleft%255B%2520r_%257Bt%252B1%257D%2520%252B%2520%255Cgamma%2520%255Cmax_%257Ba%257D%2520Q%28s_%257Bt%252B1%257D%252C%2520a%29%2520-%2520Q%28s_t%252C%2520a_t%29%2520%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DQ%28s_t%252C%2520a_t%29%2520%255Cleftarrow%2520Q%28s_t%252C%2520a_t%29%2520%252B%2520%255Calpha%2520%255Cleft%255B%2520r_%257Bt%252B1%257D%2520%252B%2520%255Cgamma%2520%255Cmax_%257Ba%257D%2520Q%28s_%257Bt%252B1%257D%252C%2520a%29%2520-%2520Q%28s_t%252C%2520a_t%29%2520%255Cright%255D" alt="equation" width="643" height="46"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The term in brackets is the &lt;strong&gt;TD error&lt;/strong&gt;: the difference between what we expected (&lt;code&gt;$Q(s_t, a_t)$&lt;/code&gt;) and what we actually observed (&lt;code&gt;$r_{t+1} + \gamma \max_a Q(s_{t+1}, a)$&lt;/code&gt;). Learning adjusts Q towards the observed value.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://link.springer.com/article/10.1007/BF00992698" rel="noopener noreferrer"&gt;Watkins and Dayan (1992)&lt;/a&gt; later proved that Q-learning converges to optimal Q-values under certain conditions: every state-action pair must be visited infinitely often, and the learning rate must satisfy the Robbins-Monro conditions (&lt;code&gt;$\sum \alpha = \infty$&lt;/code&gt;, &lt;code&gt;$\sum \alpha^2 &amp;lt; \infty$&lt;/code&gt;). Our &lt;code&gt;$\alpha = 1$&lt;/code&gt; technically violates these conditions, but the deterministic nature of tic-tac-toe means the algorithm still converges in practice.&lt;/p&gt;

&lt;h3&gt;
  
  
  Game-Playing AI: A Brief History
&lt;/h3&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjq556tds0dmfy1uoi4y7.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjq556tds0dmfy1uoi4y7.webp" alt="Timeline of game-playing AI: from Samuel's checkers (1959) to AlphaGo (2016)" width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Games have been the proving ground for AI since the field's inception. Sutton and Barto use exactly this problem as the extended example in Chapter 1 of &lt;a href="http://incompleteideas.net/book/the-book.html" rel="noopener noreferrer"&gt;Reinforcement Learning: An Introduction&lt;/a&gt;: a temporal-difference learner playing tic-tac-toe. They use it to introduce the core RL concepts before any formal machinery.&lt;/p&gt;

&lt;p&gt;The lineage of game-playing RL runs deep:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Samuel (1959)&lt;/strong&gt;: Arthur Samuel's checkers program was one of the first learning programs, using a form of temporal difference learning decades before the name existed. It beat its creator.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Tesauro (1995)&lt;/strong&gt;: Gerald Tesauro's &lt;a href="https://bkgm.com/articles/tesauro/tdl.html" rel="noopener noreferrer"&gt;TD-Gammon&lt;/a&gt; used temporal difference learning with a neural network to play backgammon at world-champion level. It discovered novel strategies that human experts later adopted.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Silver et al. (2016)&lt;/strong&gt;: &lt;a href="https://www.nature.com/articles/nature16961" rel="noopener noreferrer"&gt;AlphaGo&lt;/a&gt; combined deep neural networks with Monte Carlo tree search and self-play to defeat the world Go champion. The self-play idea is the same as ours; only the scale is different.&lt;/li&gt;
&lt;/ul&gt;

&lt;blockquote&gt;
&lt;p&gt;"The game of tic-tac-toe is a simple example, but it illustrates the fundamental principles of reinforcement learning: learning from interaction, temporal difference methods, and the trade-off between exploration and exploitation."&lt;br&gt;
-- Sutton &amp;amp; Barto, &lt;em&gt;Reinforcement Learning: An Introduction&lt;/em&gt; (2018), Chapter 1&lt;/p&gt;
&lt;/blockquote&gt;

&lt;h3&gt;
  
  
  Connection to Minimax
&lt;/h3&gt;

&lt;p&gt;For a two-player, zero-sum game like tic-tac-toe, optimal play follows the &lt;strong&gt;minimax&lt;/strong&gt; principle: each player assumes the opponent plays optimally and chooses the action that maximises the minimum possible outcome.&lt;/p&gt;

&lt;p&gt;Q-learning with self-play implicitly converges towards minimax values. When both agents are learning optimally, the Q-values for X represent &lt;code&gt;$\max$&lt;/code&gt; (X wants to maximise its reward) and the Q-values for O represent &lt;code&gt;$\min$&lt;/code&gt; (O wants to minimise X's reward, which is equivalent to maximising O's own). The self-play training process, where both agents simultaneously improve, pushes the Q-values towards this minimax equilibrium.&lt;/p&gt;

&lt;p&gt;This is why our agents discover strong strategy without being told about minimax: the competitive pressure of self-play naturally drives them there.&lt;/p&gt;
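
&lt;p&gt;For comparison, the minimax value that self-play is heading towards can be computed directly by exhaustive search. The sketch below is a plain memoised minimax, not part of the Q-learning code; the board encoding (a 9-character string) and the function names are assumptions made for this example.&lt;/p&gt;

```python
from functools import lru_cache

WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]


def winner(board):
    """Return 'X' or 'O' if a line is complete, else None."""
    for a, b, c in WIN_LINES:
        if board[a] != ' ' and board[a] == board[b] == board[c]:
            return board[a]
    return None


@lru_cache(maxsize=None)
def minimax(board, player):
    """Game value from X's point of view: +1 X wins, -1 O wins, 0 draw."""
    w = winner(board)
    if w is not None:
        return 1 if w == 'X' else -1
    if ' ' not in board:
        return 0
    nxt = 'O' if player == 'X' else 'X'
    values = [minimax(board[:i] + player + board[i + 1:], nxt)
              for i, cell in enumerate(board) if cell == ' ']
    # X maximises the value, O minimises it.
    return max(values) if player == 'X' else min(values)


print(minimax(' ' * 9, 'X'))  # 0
```

&lt;p&gt;The value of the empty board is 0, matching the statement that perfect play from both sides is a draw. Unlike Q-learning, this search needs the full rules of the game up front.&lt;/p&gt;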

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;The original thesis:&lt;/strong&gt; &lt;a href="https://www.cs.rhul.ac.uk/~chrisw/new_thesis.pdf" rel="noopener noreferrer"&gt;Watkins (1989) "Learning from Delayed Rewards"&lt;/a&gt;, Sections 3-4 for the Q-learning algorithm&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The convergence proof:&lt;/strong&gt; &lt;a href="https://link.springer.com/article/10.1007/BF00992698" rel="noopener noreferrer"&gt;Watkins &amp;amp; Dayan (1992)&lt;/a&gt; in Machine Learning&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The RL textbook:&lt;/strong&gt; &lt;a href="http://incompleteideas.net/book/the-book.html" rel="noopener noreferrer"&gt;Sutton &amp;amp; Barto (2018)&lt;/a&gt;, Chapter 1 (tic-tac-toe example) and Chapter 6 (TD learning)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Self-play at scale:&lt;/strong&gt; &lt;a href="https://arxiv.org/abs/1712.01815" rel="noopener noreferrer"&gt;Silver et al. (2017) "Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm"&lt;/a&gt;, the AlphaZero paper&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Next step:&lt;/strong&gt; Our &lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;DQN post&lt;/a&gt; shows how to replace the Q-table with a neural network for environments too large for tables&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Interactive Tools
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/q-learning-visualizer" rel="noopener noreferrer"&gt;Q-Learning Visualiser&lt;/a&gt; — Watch Q-learning train step-by-step on grid worlds in the browser&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Related Posts
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;Q-Learning from Scratch: Navigating the Frozen Lake&lt;/a&gt; (tabular Q-learning fundamentals)&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/value-iteration-q-learning-dynamic-programming-meets-rl" rel="noopener noreferrer"&gt;Value Iteration vs Q-Learning: Dynamic Programming Meets RL&lt;/a&gt; (comparing model-based and model-free approaches)&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;Deep Q-Networks: When Tables Aren't Enough&lt;/a&gt; (scaling Q-learning with neural networks)&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;Policy Gradients and REINFORCE from Scratch&lt;/a&gt; (an alternative to Q-learning that learns a policy directly)&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is Q-learning with self-play?
&lt;/h3&gt;

&lt;p&gt;Q-learning is a reinforcement learning algorithm that learns the value of each state-action pair by interacting with an environment. Self-play means both players are Q-learning agents training against each other. As each agent improves, it forces the other to improve too, driving both towards optimal play without needing a hand-crafted opponent.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why use self-play instead of training against a fixed opponent?
&lt;/h3&gt;

&lt;p&gt;A fixed opponent (random or rule-based) has a ceiling: once your agent exploits its weaknesses, it stops improving. Self-play creates an ever-improving curriculum because the opponent adapts alongside the learner. This naturally pushes both agents towards minimax-optimal strategies.&lt;/p&gt;

&lt;h3&gt;
  
  
  How does epsilon affect self-play training?
&lt;/h3&gt;

&lt;p&gt;Epsilon controls how often the agent takes a random action instead of its current best. Too low and the agents settle into a narrow set of positions, missing better strategies. Too high and learning is slow because actions are mostly random. Decaying epsilon over time (high early, low late) gives broad exploration first, then refined exploitation.&lt;/p&gt;
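
&lt;p&gt;A minimal sketch of epsilon-greedy move selection, assuming Q-values for one board state are stored in a dict keyed by cell index. The function name, the dict layout, and the example values are hypothetical and may not match the source implementation.&lt;/p&gt;

```python
import random


def epsilon_greedy(q_values, legal_moves, epsilon, rng=random):
    """Pick a random legal move with probability epsilon, else the best one.

    Moves missing from q_values are treated as Q = 0 (an assumption).
    """
    if epsilon > rng.random():
        return rng.choice(legal_moves)
    return max(legal_moves, key=lambda move: q_values.get(move, 0.0))


q_values = {2: 0.81, 5: 0.0, 8: -0.4}  # hypothetical Q-values for one state
legal = [2, 5, 8]

print(epsilon_greedy(q_values, legal, epsilon=0.0))  # 2
```

&lt;p&gt;With &lt;code&gt;epsilon=0.0&lt;/code&gt; the choice is always greedy; with &lt;code&gt;epsilon=1.0&lt;/code&gt; it is always a uniformly random legal move.&lt;/p&gt;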

&lt;h3&gt;
  
  
  Does Q-learning with self-play always converge to optimal play in tic-tac-toe?
&lt;/h3&gt;

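&lt;p&gt;In principle, but not automatically. Tic-tac-toe has a small enough state space (a few thousand reachable positions) that tabular Q-learning can visit every state-action pair many times, and with enough episodes and suitable hyperparameters the Q-values move towards the minimax equilibrium, where both agents play perfectly and every game ends in a draw. In practice, self-play makes the learning target non-stationary, and the agents in this post still lose some games to each other, so treat convergence as something to check rather than assume.&lt;/p&gt;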

&lt;h3&gt;
  
  
  Can this approach scale to more complex games like chess or Go?
&lt;/h3&gt;

&lt;p&gt;Not with a Q-table. Chess has roughly &lt;code&gt;$10^{44}$&lt;/code&gt; legal positions, making tabular Q-learning impossible. For complex games, you replace the table with a neural network (Deep Q-Networks) or use policy gradient methods. AlphaGo and AlphaZero used self-play with deep neural networks and Monte Carlo tree search to master Go, chess, and shogi.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the difference between Q-learning and minimax for game playing?
&lt;/h3&gt;

&lt;p&gt;Minimax requires a complete model of the game (all possible states and transitions) and searches the full game tree. Q-learning is model-free: it learns from experience without needing the game rules explicitly. For small games like tic-tac-toe both reach the same optimal strategy, but Q-learning generalises to environments where you cannot enumerate the full game tree.&lt;/p&gt;

</description>
      <category>reinforcementlearning</category>
      <category>gametheory</category>
    </item>
    <item>
      <title>Value Iteration vs Q-Learning: Dynamic Programming Meets RL</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Mon, 04 May 2026 13:07:14 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/value-iteration-vs-q-learning-dynamic-programming-meets-rl-3b3a</link>
      <guid>https://dev.to/berkan_sesen/value-iteration-vs-q-learning-dynamic-programming-meets-rl-3b3a</guid>
      <description>&lt;p&gt;You have a map of the frozen lake. Every crack in the ice, every slippery patch, every hole is marked. You can sit at your desk and plan the perfect route before stepping foot on the ice. That is value iteration.&lt;/p&gt;

&lt;p&gt;Now imagine you have no map. You lace up your boots and start walking. You slip, you fall into holes, you backtrack. But each time you learn a little more about which moves pay off and which ones do not. That is Q-learning.&lt;/p&gt;

&lt;p&gt;Both approaches solve the same problem (finding the best policy in a Markov Decision Process), but they start from radically different assumptions about what you know. In &lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;our earlier Q-learning post&lt;/a&gt;, we focused purely on the model-free approach. This post puts the two side by side on the same FrozenLake environment, so you can see exactly what a model buys you, and what you give up when you do not have one.&lt;/p&gt;

&lt;p&gt;By the end of this post, you will have implemented both value iteration and Q-learning from scratch, compared their convergence and policies head-to-head, and understood the Bellman equation that underpins them both.&lt;/p&gt;

&lt;h2&gt;
  
  
  Quick Win: Run Both Algorithms
&lt;/h2&gt;

&lt;p&gt;Let's see both algorithms in action. Click the badge to open the interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/value_iteration_vs_q_learning.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Watch value iteration discover optimal state values in just a few sweeps, with "heat" radiating outward from the goal:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8wzvlim0b9hdiq883zfn.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8wzvlim0b9hdiq883zfn.gif" alt="Value iteration evolving state values over sweeps, with values radiating outward from the goal state as the algorithm converges." width="600" height="600"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Here is the complete implementation for both methods:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;gymnasium&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;

&lt;span class="c1"&gt;# ── Value Iteration (model-based) ──────────────────────────
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;value_iteration&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-8&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Compute optimal V* using the Bellman optimality equation.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;nS&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;
    &lt;span class="n"&gt;nA&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;
    &lt;span class="n"&gt;V&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nS&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;delta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nS&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="n"&gt;action_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nA&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nA&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
                &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unwrapped&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
                    &lt;span class="n"&gt;action_values&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;prob&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_s&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
            &lt;span class="n"&gt;best_value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action_values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;delta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;delta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;best_value&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
            &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;best_value&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;delta&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;break&lt;/span&gt;

    &lt;span class="c1"&gt;# Extract greedy policy from V*
&lt;/span&gt;    &lt;span class="n"&gt;policy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nS&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nS&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;action_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nA&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nA&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unwrapped&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
                &lt;span class="n"&gt;action_values&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;prob&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_s&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
        &lt;span class="n"&gt;policy&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action_values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;policy&lt;/span&gt;

&lt;span class="c1"&gt;# ── Q-Learning (model-free) ────────────────────────────────
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;q_learning&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
               &lt;span class="n"&gt;epsilon_start&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;epsilon_end&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decay_rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;7e-3&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Tabular Q-learning with epsilon-greedy exploration.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;nS&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;observation_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;
    &lt;span class="n"&gt;nA&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;
    &lt;span class="n"&gt;Q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;nS&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;nA&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
    &lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;epsilon_start&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ep&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;uniform&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;=&lt;/span&gt; &lt;span class="n"&gt;epsilon&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;action_space&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
            &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:])&lt;/span&gt;

            &lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="c1"&gt;# Q-learning update
&lt;/span&gt;            &lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:])&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;next_state&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
                &lt;span class="n"&gt;epsilon&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;epsilon_end&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epsilon_start&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;epsilon_end&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;decay_rate&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;ep&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
                &lt;span class="k"&gt;break&lt;/span&gt;

    &lt;span class="n"&gt;policy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;policy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;rewards&lt;/span&gt;

&lt;span class="c1"&gt;# ── Run both on FrozenLake ─────────────────────────────────
&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;FrozenLake-v1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;is_slippery&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;V_star&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vi_policy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;value_iteration&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;Q_star&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ql_policy&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ql_rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;q_learning&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;The result:&lt;/strong&gt; Value iteration converges in 184 sweeps and produces a policy that succeeds ~73% of the time. Q-learning, after 10,000 episodes of trial and error, learns a policy that also achieves ~73% success, and agrees with the VI policy on 14 out of 16 states. Both methods find near-identical strategies, but through very different paths.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fr71y8gyfne8bqvrv9lsj.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fr71y8gyfne8bqvrv9lsj.webp" alt="Side-by-side comparison of learned policies, with arrows showing the greedy action in each state and colour showing the state value. Both methods converge to nearly identical policies." width="800" height="414"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;Both algorithms answer the same question: "What is the best action in every state?" But they go about it in fundamentally different ways.&lt;/p&gt;

&lt;h3&gt;
  
  
  Value Iteration: Planning with a Blueprint
&lt;/h3&gt;

&lt;p&gt;Value iteration has access to the environment's full transition model &lt;code&gt;$P(s' \mid s, a)$&lt;/code&gt;. This is the complete blueprint: for every state and action, you know exactly which states you might land in and with what probability.&lt;/p&gt;

&lt;p&gt;The algorithm sweeps through every state, computing the value of the best action using the &lt;strong&gt;Bellman optimality equation&lt;/strong&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DV%28s%29%2520%255Cleftarrow%2520%255Cmax_a%2520%255Csum_%257Bs%27%257D%2520P%28s%27%2520%255Cmid%2520s%252C%2520a%29%2520%255Cleft%255B%2520R%28s%252C%2520a%252C%2520s%27%29%2520%252B%2520%255Cgamma%2520%255C%252C%2520V%28s%27%29%2520%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DV%28s%29%2520%255Cleftarrow%2520%255Cmax_a%2520%255Csum_%257Bs%27%257D%2520P%28s%27%2520%255Cmid%2520s%252C%2520a%29%2520%255Cleft%255B%2520R%28s%252C%2520a%252C%2520s%27%29%2520%252B%2520%255Cgamma%2520%255C%252C%2520V%28s%27%29%2520%255Cright%255D" alt="equation" width="509" height="54"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Each sweep propagates value information one step further from the goal. In the GIF above, you can see this: after sweep 0, only the state next to the goal has any value (0.333). By sweep 5, the values have spread across the grid. By sweep 100, they have stabilised.&lt;/p&gt;

&lt;p&gt;The key line in the code is this inner loop:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;unwrapped&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
    &lt;span class="n"&gt;action_values&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;prob&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;V&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_s&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This sums over all possible outcomes of taking action &lt;code&gt;$a$&lt;/code&gt; from state &lt;code&gt;$s$&lt;/code&gt;, weighting each by its transition probability. No randomness, no sampling; it is a deterministic computation over the full model.&lt;/p&gt;
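
&lt;p&gt;To make that weighted sum concrete, here is a minimal sketch of a single backup. The transition list is hand-written to resemble the tile next to the goal and uses the same &lt;code&gt;(prob, next_state, reward, done)&lt;/code&gt; layout as &lt;code&gt;env.unwrapped.P[s][a]&lt;/code&gt;; the helper name &lt;code&gt;expected_backup&lt;/code&gt; and the specific numbers are assumptions for illustration, not values read from the environment.&lt;/p&gt;

```python
# Hypothetical transition list in the (prob, next_state, reward, done) layout
# of env.unwrapped.P[s][a]; the numbers are illustrative only.
transitions = [
    (1 / 3, 13, 0.0, False),  # slide sideways onto another frozen tile
    (1 / 3, 14, 0.0, False),  # slide into the wall and stay put
    (1 / 3, 15, 1.0, True),   # step onto the goal
]

gamma = 0.95
V = [0.0] * 16  # value estimates before the first sweep


def expected_backup(transitions, V, gamma):
    """Probability-weighted sum of reward plus discounted next-state value."""
    return sum(prob * (reward + gamma * V[next_s])
               for prob, next_s, reward, _done in transitions)


q_value = expected_backup(transitions, V, gamma)
print(round(q_value, 3))  # 0.333 while every entry of V is still zero
```

&lt;p&gt;With all values at zero, the only contribution is the 1/3 chance of collecting the goal reward, which is where a first-sweep value of 0.333 comes from.&lt;/p&gt;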

&lt;h3&gt;
  
  
  Q-Learning: Learning by Doing
&lt;/h3&gt;

&lt;p&gt;Q-learning has no access to the transition model. It learns by interacting with the environment, collecting &lt;code&gt;$(s, a, r, s')$&lt;/code&gt; tuples, and updating its Q-table one experience at a time:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;reward&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:])&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is the &lt;strong&gt;temporal difference (TD) update&lt;/strong&gt;. The term &lt;code&gt;$r + \gamma \max_{a'} Q(s', a')$&lt;/code&gt; is the &lt;strong&gt;TD target&lt;/strong&gt;: what the agent thinks the return should be based on the immediate reward plus the estimated future value. The difference between this target and the current estimate &lt;code&gt;$Q(s, a)$&lt;/code&gt; is the &lt;strong&gt;TD error&lt;/strong&gt;, which drives learning.&lt;/p&gt;
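
&lt;p&gt;A worked example may help here. The sketch below applies one TD update to a fresh Q-table using the &lt;code&gt;lr&lt;/code&gt; and &lt;code&gt;gamma&lt;/code&gt; defaults from &lt;code&gt;q_learning()&lt;/code&gt;. The experience tuple is hypothetical, and plain Python lists stand in for the NumPy array so the snippet has no dependencies.&lt;/p&gt;

```python
lr, gamma = 0.8, 0.95  # defaults of q_learning() in this post
Q = [[0.0] * 4 for _ in range(16)]  # fresh table: 16 states x 4 actions

# One hypothetical experience tuple (s, a, r, s'): a step that lands on the goal.
state, action, reward, next_state = 14, 2, 1.0, 15

td_target = reward + gamma * max(Q[next_state])  # r + gamma * max_a' Q(s', a')
td_error = td_target - Q[state][action]          # gap between target and estimate
Q[state][action] += lr * td_error                # move 80% of the way to the target

print(td_target, td_error, Q[state][action])
```

&lt;p&gt;The target and the error are both 1.0, so &lt;code&gt;Q[14][2]&lt;/code&gt; jumps from 0 to 0.8 in a single step. A smaller learning rate would make the same move more cautiously.&lt;/p&gt;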

&lt;p&gt;Because Q-learning relies on sampled experience rather than exhaustive computation, it needs many more interactions (10,000 episodes vs 184 sweeps). It also needs an exploration strategy (epsilon-greedy) to ensure it visits enough state-action pairs to build an accurate Q-table. If you have already read our &lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;Q-learning tutorial&lt;/a&gt;, these mechanics will be familiar.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why Both Reach the Same Answer
&lt;/h3&gt;

&lt;p&gt;This is not a coincidence. Both algorithms are solving the same &lt;strong&gt;Bellman optimality equation&lt;/strong&gt;. Value iteration solves it through repeated full sweeps over the state space. Q-learning solves it through stochastic approximation: each sampled experience nudges the Q-values toward the true solution, one step at a time.&lt;/p&gt;

&lt;p&gt;Given enough sweeps, value iteration converges to the exact solution, up to whatever stopping threshold you set. Given enough episodes, Q-learning converges asymptotically (with probability 1, under mild conditions on the learning rate and exploration). On FrozenLake, both methods produce policies that agree on 14 of 16 states and achieve the same ~73% success rate.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Why Not 100%? The Stochasticity Tax
&lt;/h3&gt;

&lt;p&gt;Even the optimal policy only succeeds about 73% of the time on slippery FrozenLake. This is not a bug in the algorithm. The environment is genuinely stochastic: each action has only a 1/3 chance of going in the intended direction, with 1/3 probability of sliding in each perpendicular direction. Every episode starts from the same top-left tile, so no start is doomed in advance; rather, every route to the goal passes next to holes, and over enough steps the ice will occasionally slide you in.&lt;/p&gt;
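
&lt;p&gt;The slip rule is easy to simulate without the environment. The sketch below assumes FrozenLake's action numbering (0 = LEFT, 1 = DOWN, 2 = RIGHT, 3 = UP) and the default slippery behaviour described above; &lt;code&gt;slippery_outcomes&lt;/code&gt; is a hypothetical helper name, not part of the Gymnasium API.&lt;/p&gt;

```python
import random
from collections import Counter

ACTIONS = ["LEFT", "DOWN", "RIGHT", "UP"]  # FrozenLake action indices 0..3


def slippery_outcomes(action):
    """The three equally likely moves: intended plus the two perpendicular ones."""
    return [(action - 1) % 4, action, (action + 1) % 4]


rng = random.Random(0)
intended = ACTIONS.index("RIGHT")
counts = Counter(ACTIONS[rng.choice(slippery_outcomes(intended))]
                 for _ in range(30_000))
print(counts)  # roughly 10,000 each for DOWN, RIGHT and UP; never LEFT
```

&lt;p&gt;The agent never moves opposite to its intended direction, but two out of three steps go sideways, which is why even a well-chosen action next to a hole carries real risk.&lt;/p&gt;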

&lt;h3&gt;
  
  
  Convergence: 184 Sweeps vs 10,000 Episodes
&lt;/h3&gt;

&lt;p&gt;Value iteration converges to the exact solution in 184 sweeps:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fo21x3iw67yuvgm5y6bym.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fo21x3iw67yuvgm5y6bym.webp" alt="Value iteration Bellman error drops exponentially, converging to the threshold in 184 iterations." width="800" height="449"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The Bellman error (the maximum change in any state value during a sweep) decreases exponentially. This is because the Bellman optimality update that value iteration applies is a &lt;strong&gt;contraction mapping&lt;/strong&gt;: each sweep multiplies the worst-case distance between V and the true V* by at most &lt;code&gt;$\gamma$&lt;/code&gt;. With &lt;code&gt;$\gamma = 0.95$&lt;/code&gt;, the error shrinks by at least 5% per sweep, guaranteeing convergence.&lt;/p&gt;
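
&lt;p&gt;The contraction property gives a worst-case bound: after k sweeps the error is at most gamma to the power k times the initial error. As a rough sketch, the helper below inverts that bound to estimate how many sweeps a given reduction needs. The name &lt;code&gt;sweeps_needed&lt;/code&gt; and the 1e-4 target are assumptions for illustration; they are not taken from the original code, whose stopping threshold is not shown here.&lt;/p&gt;

```python
import math


def sweeps_needed(gamma, shrink_factor):
    """Smallest k with gamma**k at or below shrink_factor (worst-case bound)."""
    return math.ceil(math.log(shrink_factor) / math.log(gamma))


# shrink_factor = 1e-4 is an assumed target, not the threshold used in the post.
for gamma in (0.90, 0.95, 0.99):
    print(gamma, sweeps_needed(gamma, 1e-4))
```

&lt;p&gt;For gamma = 0.95 the bound comes out at 180 sweeps for a 10,000-fold reduction, the same order of magnitude as the 184 sweeps observed. The exact count depends on the stopping threshold and on how large the initial error is.&lt;/p&gt;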

&lt;p&gt;Q-learning, by contrast, follows a noisier path:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ffzja4vqcs01eani7f5af.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ffzja4vqcs01eani7f5af.webp" alt="Q-learning success rate over 10,000 episodes. The training curve is noisy because the agent is still exploring, but the extracted policy reaches VI-level performance." width="800" height="456"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The rolling average hovers around 40-60% during training because the agent is still exploring (epsilon &amp;gt; 0). But the extracted greedy policy, evaluated after training with epsilon = 0, achieves the same 73% as value iteration.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Model-Based vs Model-Free Tradeoff
&lt;/h3&gt;

&lt;p&gt;This comparison crystallises one of the deepest tradeoffs in reinforcement learning:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fr92ecq0ygikamrniy75h.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fr92ecq0ygikamrniy75h.webp" alt="Head-to-head comparison: VI needs 184 iterations vs Q-learning's 10,000 episodes, both reach 73% success, but VI requires a model while Q-learning does not." width="800" height="282"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Property&lt;/th&gt;
&lt;th&gt;Value Iteration&lt;/th&gt;
&lt;th&gt;Q-Learning&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Needs transition model?&lt;/td&gt;
&lt;td&gt;Yes (env.P)&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Steps to converge&lt;/td&gt;
&lt;td&gt;184 sweeps&lt;/td&gt;
&lt;td&gt;~10,000 episodes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Optimality guarantee&lt;/td&gt;
&lt;td&gt;Exact&lt;/td&gt;
&lt;td&gt;Asymptotic&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Works for unknown environments?&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Memory&lt;/td&gt;
&lt;td&gt;O(|S|)&lt;/td&gt;
&lt;td&gt;O(|S| × |A|)&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;Value iteration is faster and guarantees exact optimality, but it requires something that is rarely available in practice: the full transition model &lt;code&gt;$P(s' \mid s, a)$&lt;/code&gt;. In robotics, game-playing, or any complex real-world task, you almost never have this. That is why model-free methods like Q-learning (and its deep successor, &lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;DQN&lt;/a&gt;) dominate modern RL.&lt;/p&gt;

&lt;h3&gt;
  
  
  Hyperparameter Sensitivity
&lt;/h3&gt;

&lt;p&gt;The original code uses a high learning rate (&lt;code&gt;$\alpha = 0.8$&lt;/code&gt;) and fast epsilon decay (&lt;code&gt;$\text{decay\_rate} = 7 \times 10^{-3}$&lt;/code&gt;). This means Q-learning explores aggressively early on and then commits to exploitation within about 1,000 episodes. The high learning rate works here because FrozenLake has a small, discrete state space. For larger problems, you would need to lower &lt;code&gt;$\alpha$&lt;/code&gt; considerably (our &lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;DQN post&lt;/a&gt; uses 0.001 with a neural network).&lt;/p&gt;
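
&lt;p&gt;To see how quickly exploration fades, the sketch below evaluates the same exponential schedule that &lt;code&gt;q_learning()&lt;/code&gt; uses, with the same default values. The function name &lt;code&gt;epsilon_at&lt;/code&gt; is introduced here for illustration only.&lt;/p&gt;

```python
import math


def epsilon_at(ep, epsilon_start=1.0, epsilon_end=0.01, decay_rate=7e-3):
    """Exploration rate after episode `ep`, same formula as in q_learning()."""
    return epsilon_end + (epsilon_start - epsilon_end) * math.exp(-decay_rate * ep)


for ep in (0, 100, 500, 1000, 5000):
    print(ep, round(epsilon_at(ep), 4))
```

&lt;p&gt;Epsilon is already down to about 0.5 after 100 episodes and sits at roughly 0.011 by episode 1,000, after which it stays just above the 0.01 floor for the remaining 9,000 episodes.&lt;/p&gt;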

&lt;p&gt;Value iteration, by contrast, has no learning rate. The discount factor &lt;code&gt;$\gamma = 0.95$&lt;/code&gt; is the main tunable parameter (the only other knob is the convergence threshold that decides when to stop sweeping), and it has a clear interpretation: how much to value future rewards relative to immediate ones. Higher gamma means the agent plans further ahead but converges more slowly, because the error is only guaranteed to shrink by a factor of gamma per sweep.&lt;/p&gt;

&lt;h3&gt;
  
  
  When to Use Which
&lt;/h3&gt;

&lt;p&gt;&lt;strong&gt;Use value iteration when:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;You have a complete model of the environment (transition probabilities and rewards)&lt;/li&gt;
&lt;li&gt;The state space is small enough to sweep over exhaustively&lt;/li&gt;
&lt;li&gt;You need guaranteed optimality&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;Use Q-learning when:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;You can only interact with the environment through trial and error&lt;/li&gt;
&lt;li&gt;The model is unknown or too complex to specify&lt;/li&gt;
&lt;li&gt;You are willing to trade computation for generality&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;In practice, most interesting problems fall into the Q-learning camp, which is why model-free methods get so much attention. Not all model-free approaches use value functions, though. Methods like &lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;the cross-entropy method&lt;/a&gt; and &lt;a href="https://sesen.ai/blog/simulated-annealing-cartpole" rel="noopener noreferrer"&gt;simulated annealing&lt;/a&gt; search policy space directly without ever estimating state values. But understanding value iteration is essential because it reveals the Bellman equation that underlies all value-based RL. As we saw in our &lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;policy gradient post&lt;/a&gt;, even gradient-based methods ultimately try to maximise the same value function.&lt;/p&gt;

&lt;h2&gt;
  
  
  Deep Dive: The Papers
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Bellman's Foundation
&lt;/h3&gt;

&lt;p&gt;Value iteration traces directly to Richard Bellman's 1957 monograph &lt;em&gt;Dynamic Programming&lt;/em&gt;. Bellman introduced the &lt;strong&gt;principle of optimality&lt;/strong&gt;:&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"An optimal policy has the property that whatever the initial state and initial decision are, the remaining decisions must constitute an optimal policy with regard to the state resulting from the first decision."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;This recursive insight leads to the Bellman optimality equation. For the state-value function:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DV%255E%2A%28s%29%2520%253D%2520%255Cmax_a%2520%255Csum_%257Bs%27%257D%2520P%28s%27%2520%255Cmid%2520s%252C%2520a%29%2520%255Cleft%255B%2520R%28s%252C%2520a%252C%2520s%27%29%2520%252B%2520%255Cgamma%2520%255C%252C%2520V%255E%2A%28s%27%29%2520%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DV%255E%2A%28s%29%2520%253D%2520%255Cmax_a%2520%255Csum_%257Bs%27%257D%2520P%28s%27%2520%255Cmid%2520s%252C%2520a%29%2520%255Cleft%255B%2520R%28s%252C%2520a%252C%2520s%27%29%2520%252B%2520%255Cgamma%2520%255C%252C%2520V%255E%2A%28s%27%29%2520%255Cright%255D" alt="equation" width="523" height="54"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;And for the action-value function (the one Q-learning estimates):&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DQ%255E%2A%28s%252C%2520a%29%2520%253D%2520%255Csum_%257Bs%27%257D%2520P%28s%27%2520%255Cmid%2520s%252C%2520a%29%2520%255Cleft%255B%2520R%28s%252C%2520a%252C%2520s%27%29%2520%252B%2520%255Cgamma%2520%255Cmax_%257Ba%27%257D%2520Q%255E%2A%28s%27%252C%2520a%27%29%2520%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DQ%255E%2A%28s%252C%2520a%29%2520%253D%2520%255Csum_%257Bs%27%257D%2520P%28s%27%2520%255Cmid%2520s%252C%2520a%29%2520%255Cleft%255B%2520R%28s%252C%2520a%252C%2520s%27%29%2520%252B%2520%255Cgamma%2520%255Cmax_%257Ba%27%257D%2520Q%255E%2A%28s%27%252C%2520a%27%29%2520%255Cright%255D" alt="equation" width="583" height="60"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Value iteration simply applies the first equation as an update rule, sweeping over all states until convergence. The convergence is guaranteed because the Bellman operator is a contraction in the sup-norm with coefficient &lt;code&gt;$\gamma &amp;lt; 1$&lt;/code&gt; (proven by Bellman himself and later formalised by Denardo, 1967).&lt;/p&gt;
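
&lt;p&gt;The contraction claim can be checked numerically on a toy problem. The sketch below uses a made-up two-state, two-action MDP (every number in &lt;code&gt;P&lt;/code&gt; is an assumption for illustration), applies one Bellman optimality backup to two arbitrary value vectors, and compares their sup-norm distance before and after.&lt;/p&gt;

```python
gamma = 0.95

# Hypothetical 2-state, 2-action MDP: P[s][a] is a list of
# (probability, next_state, reward) triples.
P = [
    [[(0.8, 0, 0.0), (0.2, 1, 1.0)], [(1.0, 1, 0.0)]],
    [[(0.5, 0, 2.0), (0.5, 1, 0.0)], [(1.0, 0, 0.5)]],
]


def bellman_backup(V):
    """Apply the Bellman optimality operator once to every state."""
    return [
        max(sum(p * (r + gamma * V[s2]) for p, s2, r in outcomes)
            for outcomes in P[s])
        for s in range(len(P))
    ]


def sup_dist(V, W):
    """Sup-norm distance between two value vectors."""
    return max(abs(v - w) for v, w in zip(V, W))


V, W = [0.0, 0.0], [10.0, -3.0]
before = sup_dist(V, W)
after = sup_dist(bellman_backup(V), bellman_backup(W))
print(after / before)  # never exceeds gamma, whatever V and W are
```

&lt;p&gt;Since V* is itself a fixed point of the backup, the same inequality with W = V* is what forces every sweep to move V closer to the optimum.&lt;/p&gt;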

&lt;h3&gt;
  
  
  Watkins' Q-Learning
&lt;/h3&gt;

&lt;p&gt;Q-learning was introduced by Christopher Watkins in his 1989 PhD thesis at Cambridge, with the convergence proof published in &lt;a href="https://link.springer.com/article/10.1007/BF00992698" rel="noopener noreferrer"&gt;Watkins &amp;amp; Dayan (1992)&lt;/a&gt;. The key insight was that you can learn &lt;code&gt;$Q^*$&lt;/code&gt; directly from experience, without ever knowing the transition model:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DQ%28s%252C%2520a%29%2520%255Cleftarrow%2520Q%28s%252C%2520a%29%2520%252B%2520%255Calpha%2520%255Cleft%255B%2520r%2520%252B%2520%255Cgamma%2520%255Cmax_%257Ba%27%257D%2520Q%28s%27%252C%2520a%27%29%2520-%2520Q%28s%252C%2520a%29%2520%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DQ%28s%252C%2520a%29%2520%255Cleftarrow%2520Q%28s%252C%2520a%29%2520%252B%2520%255Calpha%2520%255Cleft%255B%2520r%2520%252B%2520%255Cgamma%2520%255Cmax_%257Ba%27%257D%2520Q%28s%27%252C%2520a%27%29%2520-%2520Q%28s%252C%2520a%29%2520%255Cright%255D" alt="equation" width="551" height="46"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Watkins &amp;amp; Dayan proved that Q-learning converges to &lt;code&gt;$Q^*$&lt;/code&gt; with probability 1, provided:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;All state-action pairs are visited infinitely often&lt;/li&gt;
&lt;li&gt;The learning rate &lt;code&gt;$\alpha$&lt;/code&gt; satisfies: &lt;code&gt;$\sum \alpha_t = \infty$&lt;/code&gt; and &lt;code&gt;$\sum \alpha_t^2 &amp;lt; \infty$&lt;/code&gt;
&lt;/li&gt;
&lt;/ol&gt;

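&lt;p&gt;The first condition is why we need epsilon-greedy exploration. The second is a standard stochastic approximation requirement (the Robbins-Monro conditions): the steps must be large enough in total to reach the solution, yet shrink fast enough to average out the noise. A fixed learning rate, such as the 0.8 used above, satisfies the first sum but not the second, so in practice we use a fixed or slowly decaying learning rate and rely on the algorithm converging "well enough" rather than on the formal guarantee.&lt;/p&gt;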
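
&lt;p&gt;As a quick illustration of the second condition, the sketch below checks the textbook schedule alpha_t = 1/t. This schedule is an assumption chosen for the example; it is not what the code in this post uses.&lt;/p&gt;

```python
import math

N = 100_000
alphas = [1.0 / t for t in range(1, N + 1)]  # assumed schedule alpha_t = 1/t

sum_alpha = sum(alphas)                    # grows like ln(N), so it diverges
sum_alpha_sq = sum(a * a for a in alphas)  # stays below pi**2 / 6

print(round(sum_alpha, 2), round(sum_alpha_sq, 4), round(math.pi ** 2 / 6, 4))
```

&lt;p&gt;The first sum keeps growing as N increases while the second levels off, which is exactly the combination the convergence proof asks for.&lt;/p&gt;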

&lt;h3&gt;
  
  
  The DP-RL Connection
&lt;/h3&gt;

&lt;p&gt;Sutton &amp;amp; Barto's &lt;em&gt;Reinforcement Learning: An Introduction&lt;/em&gt; (2nd ed., 2018) makes the connection explicit in Chapters 4 and 6. Value iteration is presented as a dynamic programming method (Chapter 4), while Q-learning is a temporal difference method (Chapter 6). The book shows that TD methods can be viewed as &lt;strong&gt;sampling-based approximations to DP&lt;/strong&gt;: where DP backs up values using the full distribution over successors, TD methods back up using a single sampled successor.&lt;/p&gt;
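
&lt;p&gt;The difference between a full backup and a sampled backup can be sketched in a few lines. The successor distribution and value estimates below are hypothetical, and V is held fixed to keep the comparison simple; in real Q-learning the estimates being bootstrapped from change as learning proceeds.&lt;/p&gt;

```python
import random

gamma = 0.95
V = [0.0, 0.5, 1.0]  # assumed current value estimates for three successor states

# Hypothetical successor distribution for one (state, action) pair:
# (probability, next_state, reward)
successors = [(1 / 3, 0, 0.0), (1 / 3, 1, 0.0), (1 / 3, 2, 1.0)]

# DP-style backup: exact expectation over every successor.
full_backup = sum(p * (r + gamma * V[s2]) for p, s2, r in successors)

# TD-style backup: one sampled successor per update, averaged with step 1/n.
rng = random.Random(0)
estimate = 0.0
for n in range(1, 20_001):
    _, s2, r = rng.choices(successors, weights=[p for p, _, _ in successors])[0]
    target = r + gamma * V[s2]
    estimate += (target - estimate) / n

print(round(full_backup, 3), round(estimate, 3))
```

&lt;p&gt;The full backup gets the answer in one pass because it knows the probabilities. The sampled version needs thousands of draws to get close, which mirrors the sweeps-versus-episodes gap seen on FrozenLake.&lt;/p&gt;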

&lt;p&gt;This connection runs deep. Every model-free RL algorithm, from &lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;Q-learning&lt;/a&gt; to &lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;DQN&lt;/a&gt; to &lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;policy gradients&lt;/a&gt;, is implicitly solving a Bellman equation. What separates the families is how the expectation is handled: DP computes it exactly with full sweeps over a known model, while model-free methods estimate it from sampled transitions (TD) or complete episode returns (Monte Carlo).&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://press.princeton.edu/books/paperback/9780691146683/dynamic-programming" rel="noopener noreferrer"&gt;Bellman, R. (1957)&lt;/a&gt;. &lt;em&gt;Dynamic Programming&lt;/em&gt;. Princeton University Press. The foundational text.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://link.springer.com/article/10.1007/BF00992698" rel="noopener noreferrer"&gt;Watkins, C.J.C.H. &amp;amp; Dayan, P. (1992)&lt;/a&gt;. "Q-learning". &lt;em&gt;Machine Learning&lt;/em&gt;, 8, 279-292. The convergence proof.&lt;/li&gt;
&lt;li&gt;
&lt;a href="http://incompleteideas.net/book/the-book.html" rel="noopener noreferrer"&gt;Sutton, R.S. &amp;amp; Barto, A.G. (2018)&lt;/a&gt;. &lt;em&gt;Reinforcement Learning: An Introduction&lt;/em&gt;. 2nd edition. Free online. Chapters 4 (DP) and 6 (TD) are directly relevant.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://onlinelibrary.wiley.com/doi/book/10.1002/9780470316887" rel="noopener noreferrer"&gt;Puterman, M.L. (2014)&lt;/a&gt;. &lt;em&gt;Markov Decision Processes&lt;/em&gt;. The definitive theoretical reference.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/value_iteration_vs_q_learning.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Non-slippery mode&lt;/strong&gt;: Set &lt;code&gt;is_slippery=False&lt;/code&gt; and compare. Both methods should now achieve ~100% success. How does this change the convergence speed?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;8x8 grid&lt;/strong&gt;: Try &lt;code&gt;FrozenLake8x8-v1&lt;/code&gt;. Value iteration still works perfectly. How does Q-learning cope with the larger state space?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Learned transition model&lt;/strong&gt;: The original code includes a &lt;code&gt;learn_trans_matrix()&lt;/code&gt; function that estimates &lt;code&gt;$P(s' \mid s, a)$&lt;/code&gt; from random play, then runs VI on the learned model. Try this hybrid approach. How many random episodes do you need before the learned model matches the true one?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Discount factor sensitivity&lt;/strong&gt;: Vary &lt;code&gt;$\gamma$&lt;/code&gt; from 0.5 to 0.99 and plot the success rate for both methods. When does a low gamma hurt?&lt;/li&gt;
&lt;/ol&gt;
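&lt;p&gt;As a starting point for the learned-model exercise, here is a minimal count-based sketch of estimating &lt;code&gt;$P(s' \mid s, a)$&lt;/code&gt; from observed transitions. It is not the notebook's &lt;code&gt;learn_trans_matrix()&lt;/code&gt;: the function name &lt;code&gt;estimate_transition_matrix&lt;/code&gt;, the small random MDP, and the uniform fallback for unvisited pairs are all assumptions made for illustration.&lt;/p&gt;

```python
import numpy as np


def estimate_transition_matrix(transitions, n_states, n_actions):
    """Estimate P[s, a, s'] from (s, a, s') tuples by normalised counts."""
    counts = np.zeros((n_states, n_actions, n_states))
    for s, a, s_next in transitions:
        counts[s, a, s_next] += 1
    totals = counts.sum(axis=2, keepdims=True)
    # Assumption: unvisited (s, a) pairs fall back to a uniform distribution.
    uniform = np.full_like(counts, 1.0 / n_states)
    return np.where(totals > 0, counts / np.maximum(totals, 1), uniform)


rng = np.random.default_rng(0)
n_states, n_actions = 4, 2

# Hypothetical ground-truth model used only to generate random-play data.
true_P = rng.random((n_states, n_actions, n_states))
true_P /= true_P.sum(axis=2, keepdims=True)

transitions = []
for _ in range(20_000):
    s = rng.integers(n_states)
    a = rng.integers(n_actions)
    s_next = rng.choice(n_states, p=true_P[s, a])
    transitions.append((s, a, s_next))

P_hat = estimate_transition_matrix(transitions, n_states, n_actions)
print("max abs error:", np.max(np.abs(P_hat - true_P)))
```

&lt;p&gt;Rerunning with fewer transitions shows how quickly the estimate degrades, which is the question the exercise asks about.&lt;/p&gt;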

&lt;p&gt;Understanding value iteration gives you the theoretical bedrock of RL. Understanding Q-learning gives you the practical tool that works when models are not available. Together, they frame the central tradeoff that drives all of modern reinforcement learning.&lt;/p&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/q-learning-visualizer" rel="noopener noreferrer"&gt;Q-Learning Visualiser&lt;/a&gt; — Watch value iteration and Q-learning converge on grid worlds in the browser&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;Q-Learning on Frozen Lake from Scratch&lt;/a&gt; — Deep dive into tabular Q-learning on the same environment&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;Deep Q-Networks: Experience Replay and Target Networks&lt;/a&gt; — Scaling Q-learning with neural networks&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;Policy Gradients: REINFORCE from Scratch&lt;/a&gt; — The policy-based alternative to value methods&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;Cross-Entropy Method: Evolution-Style RL&lt;/a&gt; — A gradient-free approach to the same control problems&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is the difference between value iteration and Q-learning?
&lt;/h3&gt;

&lt;p&gt;Value iteration is a dynamic programming method that requires a complete model of the environment (transition probabilities and rewards) and sweeps through all states systematically. Q-learning is model-free: it learns from sampled experience without knowing the environment dynamics. Both converge to the optimal values under their respective conditions, but value iteration typically needs far less computation when a model is available, while Q-learning still works when it is not.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the Bellman equation?
&lt;/h3&gt;

&lt;p&gt;The Bellman equation expresses the value of a state as the expected immediate reward plus the discounted expected value of the next state. Its optimality form takes the maximum over actions. It is the foundation of both value iteration and Q-learning. Value iteration solves it by applying the equation to every state repeatedly until the values stop changing. Q-learning solves it incrementally by updating one state-action pair at a time from experience.&lt;/p&gt;

&lt;h3&gt;
  
  
  When should I use dynamic programming instead of Q-learning?
&lt;/h3&gt;

&lt;p&gt;Use dynamic programming (value iteration, policy iteration) when you have a complete and accurate model of the environment. This is common in games with known rules, inventory management, and operations research. When the model is unknown, too complex, or too large to enumerate, use model-free methods like Q-learning.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the difference between value iteration and policy iteration?
&lt;/h3&gt;

&lt;p&gt;Value iteration updates the value function using the Bellman optimality equation until convergence, then extracts the policy. Policy iteration alternates between evaluating the current policy exactly and improving it greedily. Policy iteration often converges in fewer iterations but each iteration is more expensive. For small state spaces, both work well.&lt;/p&gt;
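&lt;p&gt;The trade-off can be seen on a toy problem. The sketch below runs both methods on a small random MDP (an assumption for illustration, not Frozen Lake) and counts value iteration sweeps against policy iteration rounds. Policy evaluation is done exactly by solving a linear system, which is the expensive step in each round.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(1)
n_states, n_actions, gamma = 6, 3, 0.9

# Hypothetical random MDP: P[a, s, s'] is a transition probability,
# R[a, s] is the expected immediate reward.
P = rng.random((n_actions, n_states, n_states))
P /= P.sum(axis=2, keepdims=True)
R = rng.random((n_actions, n_states))


def q_values(V):
    """Action values implied by V, shape (n_actions, n_states)."""
    return R + gamma * (P @ V)


def value_iteration(tol=1e-10):
    V, sweeps = np.zeros(n_states), 0
    while True:
        V_new = q_values(V).max(axis=0)  # Bellman optimality backup
        sweeps += 1
        if tol > np.max(np.abs(V_new - V)):
            return V_new, sweeps
        V = V_new


def policy_iteration():
    policy, rounds = np.zeros(n_states, dtype=int), 0
    states = np.arange(n_states)
    while True:
        # Exact evaluation: solve (I - gamma * P_pi) V = R_pi
        P_pi = P[policy, states]
        R_pi = R[policy, states]
        V = np.linalg.solve(np.eye(n_states) - gamma * P_pi, R_pi)
        # Greedy improvement with respect to the evaluated values
        new_policy = q_values(V).argmax(axis=0)
        rounds += 1
        if np.array_equal(new_policy, policy):
            return V, rounds
        policy = new_policy


V_vi, vi_sweeps = value_iteration()
V_pi, pi_rounds = policy_iteration()
print(f"value iteration: {vi_sweeps} sweeps, policy iteration: {pi_rounds} rounds")
print("max value gap:", np.max(np.abs(V_vi - V_pi)))
```

&lt;p&gt;Both routes end at the same values. The linear solve costs on the order of the cube of the number of states per round, so the smaller round count only pays off while the state space stays modest.&lt;/p&gt;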

&lt;h3&gt;
  
  
  Does value iteration always converge?
&lt;/h3&gt;

&lt;p&gt;Yes, for finite MDPs with a discount factor less than 1. The Bellman operator is a contraction mapping in the max norm with modulus &lt;code&gt;$\gamma$&lt;/code&gt;, so the distance to the optimal value function shrinks by at least a factor of &lt;code&gt;$\gamma$&lt;/code&gt; per sweep. The number of iterations needed depends on the discount factor (higher gamma means slower convergence) and the desired precision. In practice, convergence is usually fast for small to moderate state spaces.&lt;/p&gt;
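&lt;p&gt;The contraction property is easy to check numerically. This sketch builds a small random MDP (hypothetical, chosen only for illustration), applies the Bellman optimality backup repeatedly, and records how much the values change per sweep. The ratio of successive changes should never exceed &lt;code&gt;$\gamma$&lt;/code&gt;.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
n_states, n_actions, gamma = 5, 2, 0.9

# Hypothetical random MDP: rows of P[a, s, :] sum to 1, R[a, s] is expected reward.
P = rng.random((n_actions, n_states, n_states))
P /= P.sum(axis=2, keepdims=True)
R = rng.random((n_actions, n_states))


def bellman_backup(V):
    """One sweep of the Bellman optimality operator over all states."""
    return np.max(R + gamma * (P @ V), axis=0)


V = np.zeros(n_states)
deltas = []
for _ in range(50):
    V_new = bellman_backup(V)
    deltas.append(np.max(np.abs(V_new - V)))  # max-norm change this sweep
    V = V_new

ratios = [deltas[i + 1] / deltas[i] for i in range(len(deltas) - 1)]
print("largest ratio of successive changes:", max(ratios))
print("gamma:", gamma)
```

&lt;p&gt;Raising &lt;code&gt;gamma&lt;/code&gt; towards 1 in this sketch makes the per-sweep changes decay more slowly, which matches the dependence on the discount factor described above.&lt;/p&gt;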

</description>
      <category>reinforcementlearning</category>
      <category>optimisation</category>
      <category>dynamicprogramming</category>
    </item>
    <item>
      <title>Custom Likelihoods in PyMC: One-Inflated Beta Regression for Loan Repayment</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Fri, 01 May 2026 08:47:52 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/custom-likelihoods-in-pymc-one-inflated-beta-regression-for-loan-repayment-2k5k</link>
      <guid>https://dev.to/berkan_sesen/custom-likelihoods-in-pymc-one-inflated-beta-regression-for-loan-repayment-2k5k</guid>
      <description>&lt;p&gt;When a borrower takes out a personal loan, they might repay every penny, default entirely, or land anywhere in between. The interesting variable is the fraction eventually recovered: a number between 0 and 1 for each loan in the portfolio. Plot the distribution across thousands of loans and it looks like a smooth Beta curve with a tall spike bolted on at the right edge — a mass of borrowers who repaid in full.&lt;/p&gt;

&lt;p&gt;That spike is good news for the lender, but a headache for the modeller. Standard Beta regression handles continuous outcomes on (0, 1), but it cannot produce a point mass at the boundary. Logistic regression predicts a binary paid-or-not label, throwing away the partial repayment information. Neither tool fits the data you actually have.&lt;/p&gt;

&lt;p&gt;In our &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;first PyMC post&lt;/a&gt;, we built hierarchical models using built-in distributions. In the &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;second&lt;/a&gt;, we handled non-standard likelihoods with &lt;code&gt;pm.Potential&lt;/code&gt; for right-censored survival data.&lt;/p&gt;

&lt;p&gt;This post takes the final step: writing a piecewise log-likelihood from scratch for a mixture of continuous and discrete components. By the end, you will construct a One-Inflated Beta (OIB) regression in PyMC, hand-code the Beta log-density, and infer how borrower characteristics drive both the probability of full repayment and the expected partial repayment fraction.&lt;/p&gt;

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;Click the badge below to open the full interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/bayesian/one_inflated_beta_regression.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We will generate synthetic loan data for 2,000 borrowers, fit an OIB regression model, and recover the true data-generating parameters.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fuqqebt5jz7pgqovdi171.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fuqqebt5jz7pgqovdi171.gif" alt="Two-panel animation building up as MCMC draws accumulate. Left panel shows the predicted proportion of fully-repaid loans converging to the observed 60.7%. Right panel shows the posterior predictive Beta component gradually matching the observed partial repayment histogram." width="800" height="267"&gt;&lt;/a&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pymc&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pytensor.tensor&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;arviz&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;

&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# --- Generate synthetic loan data ---
&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2000&lt;/span&gt;
&lt;span class="n"&gt;credit_score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loan_to_value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;interest_rate&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;income_ratio&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;column_stack&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;credit_score&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loan_to_value&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;interest_rate&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;income_ratio&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;feature_names&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;credit_score&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;loan_to_value&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;interest_rate&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;income_ratio&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="c1"&gt;# True parameters
&lt;/span&gt;&lt;span class="n"&gt;true_psi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;    &lt;span class="c1"&gt;# pi coefficients
&lt;/span&gt;&lt;span class="n"&gt;true_delta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;   &lt;span class="c1"&gt;# theta coefficients
&lt;/span&gt;&lt;span class="n"&gt;true_phi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;5.0&lt;/span&gt;                                          &lt;span class="c1"&gt;# Beta precision
&lt;/span&gt;
&lt;span class="c1"&gt;# Per-loan probability of full repayment (logistic link)
&lt;/span&gt;&lt;span class="n"&gt;logit_pi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;true_psi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;true_psi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;
&lt;span class="n"&gt;pi_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;logit_pi&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="c1"&gt;# Per-loan mean partial repayment (logistic link)
&lt;/span&gt;&lt;span class="n"&gt;logit_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;true_delta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;true_delta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;
&lt;span class="n"&gt;theta_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;logit_theta&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="c1"&gt;# Beta shape parameters from mean-precision
&lt;/span&gt;&lt;span class="n"&gt;alpha_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;theta_true&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;true_phi&lt;/span&gt;
&lt;span class="n"&gt;beta_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;theta_true&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;true_phi&lt;/span&gt;

&lt;span class="c1"&gt;# Sample from the OIB mixture
&lt;/span&gt;&lt;span class="n"&gt;u&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;uniform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;repayment&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;u&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;pi_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;alpha_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta_true&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="n"&gt;n_full&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;repayment&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Fully repaid: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;n_full&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; (&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;n_full&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;)&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Partial repayment: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;n_full&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; (&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;n_full&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;)&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
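&lt;p&gt;The mean-precision conversion used above is worth a quick sanity check. With &lt;code&gt;$a = \theta\phi$&lt;/code&gt; and &lt;code&gt;$b = (1-\theta)\phi$&lt;/code&gt;, the Beta distribution has mean &lt;code&gt;$\theta$&lt;/code&gt; and variance &lt;code&gt;$\theta(1-\theta)/(1+\phi)$&lt;/code&gt;, and a mixture that returns 1 with probability &lt;code&gt;$\pi$&lt;/code&gt; has mean &lt;code&gt;$\pi + (1-\pi)\theta$&lt;/code&gt;. The sketch below checks both by simulation, using single illustrative values of &lt;code&gt;theta&lt;/code&gt;, &lt;code&gt;phi&lt;/code&gt; and &lt;code&gt;pi&lt;/code&gt; rather than the per-loan values from the data generator.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative values, not the per-loan parameters generated above.
theta, phi, pi = 0.6, 5.0, 0.6

# Mean-precision to standard Beta shape parameters.
a = theta * phi
b = (1 - theta) * phi

n = 200_000
partial = rng.beta(a, b, size=n)

# Analytic moments of the Beta component under this parameterisation.
analytic_mean = theta
analytic_var = theta * (1 - theta) / (1 + phi)
print(f"Beta mean: {partial.mean():.4f} (analytic {analytic_mean:.4f})")
print(f"Beta var:  {partial.var():.4f} (analytic {analytic_var:.4f})")

# One-inflated mixture: full repayment with probability pi, else a Beta draw.
fully_repaid = pi > rng.random(n)
y = np.where(fully_repaid, 1.0, partial)
analytic_mixture_mean = pi + (1 - pi) * theta
print(f"OIB mean:  {y.mean():.4f} (analytic {analytic_mixture_mean:.4f})")
```

&lt;p&gt;A larger &lt;code&gt;phi&lt;/code&gt; concentrates the partial repayments around &lt;code&gt;theta&lt;/code&gt;, which is why it is called a precision parameter.&lt;/p&gt;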



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frjpli5x77vrceay312w1.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frjpli5x77vrceay312w1.webp" alt="Histogram of loan repayment fractions showing a tall spike at 1.0 for 1,214 fully repaid loans and a smooth Beta-shaped distribution for 786 partial repayments between 0 and 1." width="800" height="394"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Of 2,000 loans, 1,214 (60.7%) are fully repaid and 786 (39.3%) show partial repayment. The histogram immediately reveals the two populations: a tall spike at 1.0 and a continuous spread below it. No single standard distribution can capture both. Now let's build the OIB model.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Split observations by type
&lt;/span&gt;&lt;span class="n"&gt;full_idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;repayment&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;partial_idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;repayment&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;partial_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;repayment&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;partial_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;oib_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="c1"&gt;# Pi sub-model: probability of full repayment (logistic link)
&lt;/span&gt;    &lt;span class="n"&gt;psi_intercept&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;psi_coeffs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;logit_pi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;psi_intercept&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;psi_coeffs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;pi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Deterministic&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;pi&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;invlogit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logit_pi&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="c1"&gt;# Theta sub-model: mean of partial repayment Beta (logistic link)
&lt;/span&gt;    &lt;span class="n"&gt;delta_intercept&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;delta_coeffs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;logit_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;delta_intercept&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;delta_coeffs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Deterministic&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;theta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;invlogit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logit_theta&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="c1"&gt;# Phi: Beta precision (shared across all loans)
&lt;/span&gt;    &lt;span class="n"&gt;phi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Gamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;phi&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Convert mean-precision to standard Beta parameters
&lt;/span&gt;    &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;phi&lt;/span&gt;
    &lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;phi&lt;/span&gt;

    &lt;span class="c1"&gt;# Expected repayment: E[Y] = pi + (1 - pi) * theta
&lt;/span&gt;    &lt;span class="n"&gt;E_f&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Deterministic&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;E_f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pi&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# --- Piecewise log-likelihood via pm.Potential ---
&lt;/span&gt;    &lt;span class="c1"&gt;# Fully repaid loans: log(pi_i)
&lt;/span&gt;    &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;ll_full&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;full_idx&lt;/span&gt;&lt;span class="p"&gt;])))&lt;/span&gt;

    &lt;span class="c1"&gt;# Partial repayments: log(1 - pi_i) + log Beta(y_i | a_i, b_i)
&lt;/span&gt;    &lt;span class="n"&gt;pa&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;partial_idx&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;partial_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;beta_logp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gammaln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pa&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;pb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gammaln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pa&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gammaln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
                 &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pa&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;partial_values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
                 &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pb&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;partial_values&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;ll_partial&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;partial_idx&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta_logp&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;oib_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;trace&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;draws&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                      &lt;span class="n"&gt;target_accept&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.95&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                      &lt;span class="n"&gt;init&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;jitter+adapt_diag&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                                    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;phi&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8wtdk4i3mnhuap2jacol.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8wtdk4i3mnhuap2jacol.webp" alt="Trace plots for the OIB model showing posterior distributions and MCMC chains for psi_intercept, psi_coeffs, delta_intercept, delta_coeffs, and phi, all exhibiting good mixing and convergence with zero divergences." width="800" height="702"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The trace plots show healthy chains: zero divergences, good mixing across all four chains, and unimodal posteriors centred near the true parameter values. Sampling four chains of 1,000 draws each (4,000 posterior draws in total, after 3,000 tuning steps per chain) with the Potential-based likelihood took about 6 seconds.&lt;/p&gt;

&lt;p&gt;You just fitted a custom Bayesian mixture model with 11 free parameters. Now let's understand how each piece works.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Two populations, one model
&lt;/h3&gt;

&lt;p&gt;Our data contains two distinct groups. Some borrowers repay their loan in full (repayment fraction = 1.0), and others repay partially (0 &amp;lt; fraction &amp;lt; 1). The OIB model treats this as a mixture: with probability &lt;code&gt;$\pi_i$&lt;/code&gt; the outcome is exactly 1, and with probability &lt;code&gt;$1 - \pi_i$&lt;/code&gt; it follows a Beta distribution.&lt;/p&gt;

&lt;p&gt;Both &lt;code&gt;$\pi_i$&lt;/code&gt; and the Beta mean &lt;code&gt;$\theta_i$&lt;/code&gt; vary across borrowers. A high credit score might increase both the chance of full repayment and the expected partial repayment. The model captures these relationships through separate linear predictors with logistic links, ensuring both quantities stay between 0 and 1.&lt;/p&gt;
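
&lt;p&gt;The two-population story can be simulated in a few lines of NumPy. This is a minimal sketch with a single standardised covariate and made-up coefficients, not the values used to generate the data in this post; &lt;code&gt;inv_logit&lt;/code&gt;, &lt;code&gt;x&lt;/code&gt; and the coefficient values are assumptions for illustration only.&lt;/p&gt;

```python
import numpy as np

def inv_logit(z):
    """Logistic link: maps any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
n = 5000
x = rng.normal(size=n)              # one standardised covariate (illustrative)

# Illustrative coefficients, NOT the ones used in the post
pi = inv_logit(-0.5 + 0.8 * x)      # probability of full repayment
theta = inv_logit(0.3 + 0.4 * x)    # Beta mean for partial repayments
phi = 5.0                           # shared Beta precision

# Step 1: decide which component generates each loan
is_full = rng.binomial(1, pi).astype(bool)

# Step 2: draw a partial fraction for every loan, then overwrite the full ones
partial = rng.beta(theta * phi, (1.0 - theta) * phi)
y = np.where(is_full, 1.0, partial)

# Expected repayment per loan: E[Y] = pi + (1 - pi) * theta
expected = pi + (1.0 - pi) * theta
print(f"share fully repaid: {is_full.mean():.3f}")
print(f"sample mean of y:   {y.mean():.3f}")
print(f"mean of E[Y]:       {expected.mean():.3f}")
```

&lt;p&gt;The sample mean of &lt;code&gt;y&lt;/code&gt; should land close to the average of &lt;code&gt;pi + (1 - pi) * theta&lt;/code&gt;, which is the same expected-repayment quantity the model tracks as &lt;code&gt;E_f&lt;/code&gt;.&lt;/p&gt;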

&lt;h3&gt;
  
  
  The piecewise log-likelihood
&lt;/h3&gt;

&lt;p&gt;The OIB density is a mixture of a point mass and a continuous distribution. For observation &lt;code&gt;$y_i$&lt;/code&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dp%28y_i%29%2520%253D%2520%255Cbegin%257Bcases%257D%2520%255Cpi_i%2520%2526%2520%255Ctext%257Bif%2520%257D%2520y_i%2520%253D%25201%2520%255C%255C%2520%281%2520-%2520%255Cpi_i%29%2520%255Ccdot%2520f_%257B%255Ctext%257BBeta%257D%257D%28y_i%2520%255Cmid%2520%255Calpha_i%252C%2520%255Cbeta_i%29%2520%2526%2520%255Ctext%257Bif%2520%257D%25200%2520%253C%2520y_i%2520%253C%25201%2520%255Cend%257Bcases%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dp%28y_i%29%2520%253D%2520%255Cbegin%257Bcases%257D%2520%255Cpi_i%2520%2526%2520%255Ctext%257Bif%2520%257D%2520y_i%2520%253D%25201%2520%255C%255C%2520%281%2520-%2520%255Cpi_i%29%2520%255Ccdot%2520f_%257B%255Ctext%257BBeta%257D%257D%28y_i%2520%255Cmid%2520%255Calpha_i%252C%2520%255Cbeta_i%29%2520%2526%2520%255Ctext%257Bif%2520%257D%25200%2520%253C%2520y_i%2520%253C%25201%2520%255Cend%257Bcases%257D" alt="equation" width="521" height="75"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Taking logs:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Clog%2520p%28y_i%29%2520%253D%2520%255Cbegin%257Bcases%257D%2520%255Clog%2520%255Cpi_i%2520%2526%2520%255Ctext%257Bif%2520%257D%2520y_i%2520%253D%25201%2520%255C%255C%2520%255Clog%281%2520-%2520%255Cpi_i%29%2520%252B%2520%255Clog%2520f_%257B%255Ctext%257BBeta%257D%257D%28y_i%2520%255Cmid%2520%255Calpha_i%252C%2520%255Cbeta_i%29%2520%2526%2520%255Ctext%257Bif%2520%257D%25200%2520%253C%2520y_i%2520%253C%25201%2520%255Cend%257Bcases%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Clog%2520p%28y_i%29%2520%253D%2520%255Cbegin%257Bcases%257D%2520%255Clog%2520%255Cpi_i%2520%2526%2520%255Ctext%257Bif%2520%257D%2520y_i%2520%253D%25201%2520%255C%255C%2520%255Clog%281%2520-%2520%255Cpi_i%29%2520%252B%2520%255Clog%2520f_%257B%255Ctext%257BBeta%257D%257D%28y_i%2520%255Cmid%2520%255Calpha_i%252C%2520%255Cbeta_i%29%2520%2526%2520%255Ctext%257Bif%2520%257D%25200%2520%253C%2520y_i%2520%253C%25201%2520%255Cend%257Bcases%257D" alt="equation" width="635" height="75"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The addition in the second branch is critical: adding logs corresponds to multiplying the mixing weight &lt;code&gt;$(1 - \pi_i)$&lt;/code&gt; by the Beta density in probability space. A common mistake is to multiply the two log terms instead (i.e. &lt;code&gt;log(1-pi) * log(Beta(...))&lt;/code&gt;), which has no probabilistic interpretation.&lt;/p&gt;

&lt;p&gt;We implement this by splitting observations into two groups and adding each group's log-likelihood as a separate &lt;code&gt;pm.Potential&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Fully repaid: sum of log(pi_i) over fully-repaid loans
&lt;/span&gt;&lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;ll_full&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;full_idx&lt;/span&gt;&lt;span class="p"&gt;])))&lt;/span&gt;

&lt;span class="c1"&gt;# Partial: sum of log(1 - pi_i) + Beta_logpdf(y_i) over partial loans
&lt;/span&gt;&lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;ll_partial&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;partial_idx&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta_logp&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This pattern should feel familiar from &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;Post 21&lt;/a&gt;, where we used &lt;code&gt;pm.Potential&lt;/code&gt; to handle right-censored observations. The principle is the same: when your likelihood has distinct branches for different observation types, split them into separate Potential terms.&lt;/p&gt;
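
&lt;p&gt;To see the split-and-add logic outside PyMC, here is a small NumPy/SciPy sketch on five toy observations. The values of &lt;code&gt;pi&lt;/code&gt;, &lt;code&gt;a&lt;/code&gt; and &lt;code&gt;b&lt;/code&gt; are made up, and building &lt;code&gt;full_idx&lt;/code&gt;, &lt;code&gt;partial_idx&lt;/code&gt; and &lt;code&gt;partial_values&lt;/code&gt; with &lt;code&gt;np.where&lt;/code&gt; is an assumption about how such index arrays could be constructed. The sketch checks that summing the two branch log-likelihoods matches logging the mixture density directly.&lt;/p&gt;

```python
import numpy as np
from scipy.stats import beta as beta_dist

# Toy data: two fully repaid loans and three partial repayments
y = np.array([1.0, 0.35, 1.0, 0.80, 0.62])
pi = np.array([0.7, 0.4, 0.6, 0.5, 0.3])   # illustrative P(full repayment)
a = np.array([2.0, 1.5, 3.0, 2.5, 2.0])    # illustrative Beta shape parameters
b = np.array([1.0, 2.0, 1.5, 1.0, 1.2])

# Split observations by branch (assumed construction of the index arrays)
full_idx = np.where(y == 1.0)[0]
partial_idx = np.where(y != 1.0)[0]
partial_values = y[partial_idx]

# Branch 1: log(pi_i) for fully repaid loans
ll_full = np.sum(np.log(pi[full_idx]))

# Branch 2: log(1 - pi_i) + log Beta(y_i | a_i, b_i) for partial repayments
beta_logp = beta_dist.logpdf(partial_values, a[partial_idx], b[partial_idx])
ll_partial = np.sum(np.log(1 - pi[partial_idx]) + beta_logp)

# Same quantity built in probability space first, then logged
density = (1 - pi[partial_idx]) * beta_dist.pdf(
    partial_values, a[partial_idx], b[partial_idx])
ll_direct = np.sum(np.log(pi[full_idx])) + np.sum(np.log(density))

print(ll_full + ll_partial, ll_direct)
```

&lt;p&gt;The two printed numbers should agree up to floating-point error, which is exactly the add-in-log-space rule from the equations above.&lt;/p&gt;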

&lt;h3&gt;
  
  
  Hand-coding the Beta log-density
&lt;/h3&gt;

&lt;p&gt;Rather than relying on &lt;code&gt;pm.logp(pm.Beta.dist(...), value)&lt;/code&gt;, we compute the Beta log-density directly using the log-gamma function &lt;code&gt;pt.gammaln&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;beta_logp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gammaln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gammaln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gammaln&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
             &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



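&lt;p&gt;This follows from taking the log of the Beta density &lt;code&gt;$f(y \mid \alpha, \beta) = \frac{\Gamma(\alpha + \beta)}{\Gamma(\alpha)\Gamma(\beta)} y^{\alpha-1}(1-y)^{\beta-1}$&lt;/code&gt;, with &lt;code&gt;a&lt;/code&gt; and &lt;code&gt;b&lt;/code&gt; standing in for &lt;code&gt;$\alpha$&lt;/code&gt; and &lt;code&gt;$\beta$&lt;/code&gt;. In the full model these are &lt;code&gt;pa&lt;/code&gt;, &lt;code&gt;pb&lt;/code&gt; and &lt;code&gt;partial_values&lt;/code&gt;, restricted to the partial-repayment rows. Writing it out explicitly has two advantages: you can see exactly what the sampler is differentiating through, and you avoid potential issues with PyMC's internal distribution objects when used inside Potential expressions.&lt;/p&gt;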
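
&lt;p&gt;A hand-written density is easy to get subtly wrong, so it is worth checking against a reference implementation. The sketch below re-expresses the same formula in NumPy with &lt;code&gt;scipy.special.gammaln&lt;/code&gt; and compares it to &lt;code&gt;scipy.stats.beta.logpdf&lt;/code&gt;; the helper name &lt;code&gt;beta_logpdf&lt;/code&gt; and the test values are assumptions for illustration.&lt;/p&gt;

```python
import numpy as np
from scipy.special import gammaln
from scipy.stats import beta as beta_dist

def beta_logpdf(y, a, b):
    """Log of the Beta(a, b) density at y, written out term by term."""
    log_norm = gammaln(a + b) - gammaln(a) - gammaln(b)
    return log_norm + (a - 1) * np.log(y) + (b - 1) * np.log(1 - y)

# Arbitrary test points strictly inside (0, 1) with arbitrary shape parameters
y = np.array([0.10, 0.50, 0.93])
a = np.array([3.0, 1.2, 4.5])
b = np.array([2.0, 0.8, 0.5])

manual = beta_logpdf(y, a, b)
reference = beta_dist.logpdf(y, a, b)

# Largest absolute discrepancy between the two implementations
print(np.max(np.abs(manual - reference)))
```

&lt;p&gt;The printed discrepancy should be at the level of floating-point rounding. Note that the formula is only defined for values strictly between 0 and 1, which is why the fully repaid loans are handled by a separate branch.&lt;/p&gt;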

&lt;h3&gt;
  
  
  The model structure
&lt;/h3&gt;

&lt;p&gt;The model has three sub-components connected by link functions:&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Pi sub-model&lt;/strong&gt; controls which mixture component generates each observation. A logistic link maps the linear predictor &lt;code&gt;$\psi_0 + \psi_1 x_{\text{credit}} + \psi_2 x_{\text{ltv}} + \psi_3 x_{\text{rate}} + \psi_4 x_{\text{income}}$&lt;/code&gt; to a probability. A positive &lt;code&gt;$\psi_1$&lt;/code&gt; means higher credit scores increase the chance of full repayment.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Theta sub-model&lt;/strong&gt; sets the mean of the Beta distribution for partial repayments, also through a logistic link with its own coefficients &lt;code&gt;$\delta_0, \ldots, \delta_4$&lt;/code&gt;. This captures a subtlety that pure classification misses: among borrowers who do not fully repay, some covariates still push the partial fraction higher.&lt;/p&gt;

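&lt;p&gt;&lt;strong&gt;Phi&lt;/strong&gt; is a single shared precision parameter for the Beta component. For a fixed mean &lt;code&gt;$\theta_i$&lt;/code&gt; the Beta variance is &lt;code&gt;$\theta_i(1 - \theta_i)/(1 + \phi)$&lt;/code&gt;, so higher phi means less variance in partial repayments. It uses a &lt;code&gt;$\text{Gamma}(2, 0.5)$&lt;/code&gt; prior (shape 2, rate 0.5) with mean 2 / 0.5 = 4, which favours moderate precision values.&lt;/p&gt;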
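
&lt;p&gt;The effect of phi is easiest to see numerically. The following sketch converts a mean and precision into standard Beta shape parameters, as the model does with &lt;code&gt;a = theta * phi&lt;/code&gt; and &lt;code&gt;b = (1 - theta) * phi&lt;/code&gt;, and prints the resulting variance for a few precision values. The helper names and the chosen &lt;code&gt;theta&lt;/code&gt; and &lt;code&gt;phi&lt;/code&gt; values are assumptions for illustration.&lt;/p&gt;

```python
import numpy as np

def beta_shapes(theta, phi):
    """Mean-precision parameterisation to standard Beta shapes (a, b)."""
    return theta * phi, (1.0 - theta) * phi

def beta_mean_var(a, b):
    """Mean and variance of a Beta(a, b) distribution."""
    mean = a / (a + b)
    var = a * b / ((a + b) ** 2 * (a + b + 1.0))
    return mean, var

theta = 0.6  # illustrative Beta mean
results = {}
for phi in (2.0, 5.0, 20.0):
    a, b = beta_shapes(theta, phi)
    mean, var = beta_mean_var(a, b)
    results[phi] = (mean, var)
    print(f"phi={phi:5.1f}  a={a:5.2f}  b={b:5.2f}  mean={mean:.2f}  var={var:.4f}")

# Mean of a Gamma prior with shape 2 and rate 0.5
prior_mean = 2.0 / 0.5
print(f"prior mean of phi: {prior_mean}")
```

&lt;p&gt;The mean stays at &lt;code&gt;theta&lt;/code&gt; for every precision value while the variance shrinks as &lt;code&gt;phi&lt;/code&gt; grows, which is what makes this parameterisation convenient for regression on the mean.&lt;/p&gt;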

&lt;h3&gt;
  
  
  Checking the fit
&lt;/h3&gt;

&lt;p&gt;Let's compare the estimated coefficients to the true values we used to generate the data.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;summary&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                                        &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;phi&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

&lt;span class="n"&gt;true_vals&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;concatenate&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;true_psi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;true_delta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;true_phi&lt;/span&gt;&lt;span class="p"&gt;]])&lt;/span&gt;
&lt;span class="n"&gt;param_names&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_coeffs[&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;]&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt;
               &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_coeffs[&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;]&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;phi&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;y_pos&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;param_names&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;means&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;mean&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;
&lt;span class="n"&gt;hdi_low&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;hdi_3%&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;
&lt;span class="n"&gt;hdi_high&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;hdi_97%&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;

&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;errorbar&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;means&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pos&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;xerr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;means&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;hdi_low&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hdi_high&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;means&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
            &lt;span class="n"&gt;fmt&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;o&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;steelblue&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;capsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Posterior (94% HDI)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;scatter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;true_vals&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_pos&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;marker&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;x&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;crimson&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;80&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
           &lt;span class="n"&gt;zorder&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;True value&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_yticks&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_pos&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_yticklabels&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;param_names&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gray&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Value&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;lower right&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Parameter Recovery: Posterior vs True Values&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkwm5v25kdcuv0vf13111.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkwm5v25kdcuv0vf13111.webp" alt="Forest plot showing posterior means with 94% HDI intervals for all 11 model parameters alongside the true values used for data generation, demonstrating accurate parameter recovery across both the pi and theta sub-models." width="800" height="596"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Every true value falls within its 94% highest density interval. The model correctly identifies that credit score has the strongest positive effect on full repayment (psi_coeffs[0] = 0.85, true: 0.8), while loan-to-value ratio is the strongest negative predictor (psi_coeffs[1] = -0.58, true: -0.6). The precision parameter phi is recovered at 5.47 (true: 5.0), and the effective sample sizes all exceed 2,500.&lt;/p&gt;

&lt;h3&gt;
  
  
  Posterior predictive check
&lt;/h3&gt;

&lt;p&gt;The ultimate test: can the model reproduce the observed data distribution, including the spike at 1.0? Since we used &lt;code&gt;pm.Potential&lt;/code&gt; rather than an observed distribution, we generate predictive samples manually from the posterior:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Extract posterior samples
&lt;/span&gt;&lt;span class="n"&gt;psi_int_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;psi_coeff_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;psi_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;delta_int_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;delta_coeff_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;delta_coeffs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;phi_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;phi&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="n"&gt;rng&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;default_rng&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;n_draws&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;500&lt;/span&gt;
&lt;span class="n"&gt;ppc_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;n_draws&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_draws&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;lp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;psi_int_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;psi_coeff_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;pi_i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;lp&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;lt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;delta_int_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;delta_coeff_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;theta_i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;lt&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;a_i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b_i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;theta_i&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;phi_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;theta_i&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;phi_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;u_i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rng&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;uniform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;ppc_samples&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;where&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;u_i&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="n"&gt;pi_i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;rng&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a_i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b_i&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;hist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;repayment&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;density&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;steelblue&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Observed&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;edgecolor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;white&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;hist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ppc_samples&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;bins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;density&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;coral&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Posterior predictive&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;edgecolor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;white&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Repayment Fraction&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Density&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Posterior Predictive Check&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Falj4mzje5lrej3mucg09.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Falj4mzje5lrej3mucg09.webp" alt="Two-panel posterior predictive check. Left: observed vs predicted proportion of fully-repaid loans (both around 60.7%). Right: observed vs predicted density of partial repayments, showing the Beta component accurately captures the continuous distribution shape." width="800" height="277"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The posterior predictive distribution matches both the spike at 1.0 and the shape of the partial repayment component. This is something neither pure Beta regression nor logistic regression can achieve.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The mean-precision parameterisation
&lt;/h3&gt;

&lt;p&gt;The standard Beta distribution uses shape parameters &lt;code&gt;$\alpha$&lt;/code&gt; and &lt;code&gt;$\beta$&lt;/code&gt;, but these are difficult to interpret. Knowing that a borrower's partial repayment follows a Beta with &lt;code&gt;$\alpha = 2.8$&lt;/code&gt; and &lt;code&gt;$\beta = 2.1$&lt;/code&gt; tells you almost nothing at a glance. The mean-precision reparameterisation solves this:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu%2520%253D%2520%255Cfrac%257B%255Calpha%257D%257B%255Calpha%2520%252B%2520%255Cbeta%257D%252C%2520%255Cquad%2520%255Cphi%2520%253D%2520%255Calpha%2520%252B%2520%255Cbeta" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu%2520%253D%2520%255Cfrac%257B%255Calpha%257D%257B%255Calpha%2520%252B%2520%255Cbeta%257D%252C%2520%255Cquad%2520%255Cphi%2520%253D%2520%255Calpha%2520%252B%2520%255Cbeta" alt="equation" width="257" height="50"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Now &lt;code&gt;$\mu$&lt;/code&gt; is the mean of the distribution (the expected partial repayment fraction) and &lt;code&gt;$\phi$&lt;/code&gt; is the precision (higher means less spread). The inverse mapping recovers the standard parameters:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha%2520%253D%2520%255Cmu%2520%255Cphi%252C%2520%255Cquad%2520%255Cbeta%2520%253D%2520%281%2520-%2520%255Cmu%29%2520%255Cphi" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha%2520%253D%2520%255Cmu%2520%255Cphi%252C%2520%255Cquad%2520%255Cbeta%2520%253D%2520%281%2520-%2520%255Cmu%29%2520%255Cphi" alt="equation" width="251" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;In our model, &lt;code&gt;$\mu$&lt;/code&gt; is called &lt;code&gt;$\theta$&lt;/code&gt; and depends on covariates through a logistic link. The precision &lt;code&gt;$\phi$&lt;/code&gt; is shared across all observations, which assumes that the variance of partial repayments (given the mean) is the same for all borrowers. This is a simplification; a fully heteroscedastic model would give &lt;code&gt;$\phi$&lt;/code&gt; its own linear predictor.&lt;/p&gt;
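
&lt;p&gt;As a quick sanity check on the two mappings above, here is a minimal sketch in plain Python. The helper names (&lt;code&gt;to_shape&lt;/code&gt;, &lt;code&gt;to_mean_precision&lt;/code&gt;, &lt;code&gt;beta_variance&lt;/code&gt;) are illustrative and not part of the model code. The variance formula is the standard Beta variance rewritten in terms of the mean and precision.&lt;/p&gt;

```python
import math


def to_shape(mu, phi):
    """Map mean-precision (mu, phi) to the standard Beta shapes (alpha, beta)."""
    return mu * phi, (1.0 - mu) * phi


def to_mean_precision(alpha, beta):
    """Map the standard Beta shapes (alpha, beta) back to (mu, phi)."""
    return alpha / (alpha + beta), alpha + beta


def beta_variance(mu, phi):
    """Variance of a Beta distribution written in mean-precision form."""
    return mu * (1.0 - mu) / (1.0 + phi)


# The hard-to-read example from the text: alpha = 2.8, beta = 2.1
mu, phi = to_mean_precision(2.8, 2.1)
alpha, beta = to_shape(mu, phi)

print(f"mu = {mu:.3f}, phi = {phi:.1f}")           # mean about 0.571, precision 4.9
print(f"round trip: alpha = {alpha:.1f}, beta = {beta:.1f}")
print(f"variance = {beta_variance(mu, phi):.4f}")
```

&lt;p&gt;Read this way, the same borrower has an expected partial repayment of about 57% with a fairly low precision of 4.9, which is much easier to reason about than the raw shape parameters.&lt;/p&gt;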

&lt;h3&gt;
  
  
  Why logistic links?
&lt;/h3&gt;

&lt;p&gt;Both &lt;code&gt;$\pi$&lt;/code&gt; (probability of full repayment) and &lt;code&gt;$\theta$&lt;/code&gt; (mean of the Beta) must live in (0, 1). The logistic function &lt;code&gt;$\sigma(x) = 1 / (1 + e^{-x})$&lt;/code&gt; maps any real-valued linear predictor to this interval. It is the inverse of the logit link used in &lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;logistic regression and Bayesian classification&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;The priors reflect the link: &lt;code&gt;$\text{Normal}(0, 5)$&lt;/code&gt; on the intercepts allows the baseline probability to range widely, while &lt;code&gt;$\text{Normal}(0, 1)$&lt;/code&gt; on the slope coefficients gently regularises each covariate's effect. On the logit scale, a coefficient of 1.0 multiplies the odds by &lt;code&gt;$e \approx 2.7$&lt;/code&gt; for each one-standard-deviation increase in a standardised covariate, so a &lt;code&gt;$\text{Normal}(0, 1)$&lt;/code&gt; prior is mildly informative.&lt;/p&gt;
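
&lt;p&gt;To see what a unit coefficient means on the odds scale, the following sketch compares the odds before and after adding 1.0 to a linear predictor. The baseline of 0.0 and the helper names are assumptions chosen for illustration; any baseline gives the same ratio.&lt;/p&gt;

```python
import math


def sigmoid(x):
    """Logistic function: maps a real-valued linear predictor into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def odds(p):
    """Convert a probability into odds."""
    return p / (1.0 - p)


baseline_lp = 0.0                 # hypothetical linear predictor, gives p = 0.5
coefficient = 1.0                 # slope on a standardised covariate
shifted_lp = baseline_lp + coefficient * 1.0   # covariate raised by one SD

p_base = sigmoid(baseline_lp)
p_shifted = sigmoid(shifted_lp)
odds_ratio = odds(p_shifted) / odds(p_base)

print(f"p: {p_base:.3f} to {p_shifted:.3f}")   # 0.500 to 0.731
print(f"odds ratio: {odds_ratio:.3f}")         # equals e, about 2.718
```

&lt;p&gt;Because the odds are the exponential of the linear predictor, the ratio is always &lt;code&gt;exp(coefficient)&lt;/code&gt;, regardless of where the baseline sits.&lt;/p&gt;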

&lt;h3&gt;
  
  
  The expected value formula
&lt;/h3&gt;

&lt;p&gt;The overall expected repayment for borrower &lt;code&gt;$i$&lt;/code&gt; combines both components:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmathbb%257BE%257D%255BY_i%255D%2520%253D%2520%255Cpi_i%2520%255Ccdot%25201%2520%252B%2520%281%2520-%2520%255Cpi_i%29%2520%255Ccdot%2520%255Ctheta_i%2520%253D%2520%255Cpi_i%2520%252B%2520%281%2520-%2520%255Cpi_i%29%255Ctheta_i" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmathbb%257BE%257D%255BY_i%255D%2520%253D%2520%255Cpi_i%2520%255Ccdot%25201%2520%252B%2520%281%2520-%2520%255Cpi_i%29%2520%255Ccdot%2520%255Ctheta_i%2520%253D%2520%255Cpi_i%2520%252B%2520%281%2520-%2520%255Cpi_i%29%255Ctheta_i" alt="equation" width="467" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This is the &lt;code&gt;E_f&lt;/code&gt; deterministic in our model. It allows you to rank borrowers by expected repayment even when their risk profiles differ in how they fail: one borrower might have a high chance of full repayment but low partial repayment if they default, while another has a moderate chance of full repayment but high partial recovery.&lt;/p&gt;
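
&lt;p&gt;A small worked example makes the point concrete. The two borrowers and their parameter values below are made up for illustration; &lt;code&gt;expected_repayment&lt;/code&gt; is a hypothetical helper that applies the formula above to scalar inputs.&lt;/p&gt;

```python
def expected_repayment(pi, theta):
    """E[Y] = pi * 1 + (1 - pi) * theta for a one-inflated Beta outcome."""
    return pi + (1.0 - pi) * theta


# Hypothetical borrower A: likely to repay in full, poor recovery otherwise
borrower_a = expected_repayment(pi=0.80, theta=0.30)

# Hypothetical borrower B: coin-flip on full repayment, strong partial recovery
borrower_b = expected_repayment(pi=0.50, theta=0.72)

print(f"A: {borrower_a:.2f}")  # 0.80 + 0.20 * 0.30 = 0.86
print(f"B: {borrower_b:.2f}")  # 0.50 + 0.50 * 0.72 = 0.86
```

&lt;p&gt;Both borrowers have the same expected repayment, yet the way they fail is very different, which is exactly the information a single expected value hides and the two sub-models expose.&lt;/p&gt;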

&lt;h3&gt;
  
  
  Why pm.Potential and not pm.CustomDist?
&lt;/h3&gt;

&lt;p&gt;PyMC offers two ways to implement custom likelihoods. &lt;code&gt;pm.CustomDist&lt;/code&gt; lets you define a distribution from its &lt;code&gt;logp&lt;/code&gt; function, which would look like this for OIB:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;oib_logp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;switch&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eq&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
        &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
        &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;pi&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;logp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Beta&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is elegant but fragile. The &lt;code&gt;pt.switch&lt;/code&gt; operator evaluates both branches for every observation during automatic differentiation.&lt;/p&gt;

&lt;p&gt;When &lt;code&gt;value = 1.0&lt;/code&gt;, the Beta branch computes &lt;code&gt;pm.logp(Beta, 1.0)&lt;/code&gt;, which returns negative infinity (the Beta density is zero at the upper boundary whenever &lt;code&gt;$\beta &amp;gt; 1$&lt;/code&gt;). Even though the switch selects the other branch, the gradient through the infinite branch is non-finite and corrupts the NUTS sampler. The result: a 100% divergence rate.&lt;/p&gt;
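
&lt;p&gt;The failure mode can be illustrated without PyMC. The sketch below is a simplified picture of one chain-rule step, not PyTensor's actual implementation: it assumes the unselected branch receives an upstream gradient of zero, which is then multiplied by that branch's local derivative. For the Beta log-density, the derivative with respect to the second shape parameter is &lt;code&gt;log(1 - y)&lt;/code&gt;, which diverges at the boundary.&lt;/p&gt;

```python
import math


def chain_rule_step(upstream, local):
    """One backpropagation step: upstream gradient times local derivative."""
    return upstream * local


# Interior observation: the unselected branch contributes a clean zero.
local_interior = math.log(1.0 - 0.4)
grad_interior = chain_rule_step(0.0, local_interior)

# Boundary observation (y = 1): log(1 - y) is log(0). math.log raises on 0,
# so the limiting value is written out explicitly.
local_boundary = float("-inf")
grad_boundary = chain_rule_step(0.0, local_boundary)

print(grad_interior == 0.0)        # True: zero times a finite number
print(math.isnan(grad_boundary))   # True: zero times infinity is NaN
```

&lt;p&gt;Under IEEE 754 arithmetic, zero times infinity is NaN rather than zero, so a single boundary observation is enough to poison the summed gradient.&lt;/p&gt;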

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9f281ukhszqz0zw0hqt2.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9f281ukhszqz0zw0hqt2.webp" alt="Diagram of the One-Inflated Beta model showing covariates feeding into two parallel sub-models: a logistic regression for pi (full repayment probability) and a logistic regression for theta (partial repayment mean), which combine through a piecewise likelihood with a shared precision parameter phi." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The &lt;code&gt;pm.Potential&lt;/code&gt; approach avoids this entirely. By pre-splitting observations into fully-repaid and partial groups, the Beta density is never evaluated at the boundary. This is the same pattern we used for &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;censored data in survival analysis&lt;/a&gt;: separate the observation types, compute each group's log-likelihood independently, and add them as Potential terms.&lt;/p&gt;

&lt;p&gt;The trade-off is that &lt;code&gt;pm.Potential&lt;/code&gt; does not enable &lt;code&gt;pm.sample_posterior_predictive&lt;/code&gt; out of the box (you need to write manual prediction code, as we did). For many production workflows, that is a minor inconvenience compared to the reliability gain.&lt;/p&gt;

&lt;h3&gt;
  
  
  Sampling considerations
&lt;/h3&gt;

&lt;p&gt;Our sampling configuration follows the original code that inspired this tutorial:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;3,000 tuning steps&lt;/strong&gt; with 1,000 posterior draws per chain. The long warm-up helps the NUTS sampler adapt its step size to the geometry of the piecewise likelihood.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;4 chains&lt;/strong&gt; for convergence diagnostics. Agreement across four independent chains, measured by &lt;code&gt;$\hat{R}$&lt;/code&gt; and effective sample size, is good evidence that the sampler has converged rather than become stuck in one region of the posterior.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;target_accept=0.95&lt;/code&gt;&lt;/strong&gt; raises the target acceptance rate from the default 0.8. NUTS responds by adapting to a smaller step size, which reduces divergences in models with sharp likelihood boundaries.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;code&gt;init='jitter+adapt_diag'&lt;/code&gt;&lt;/strong&gt; initialises each chain near the prior mean with small random perturbations. A practical note from the original code: if covariates have very different scales (e.g., one ranges from 0 to 1 while another ranges from 0 to 200), the default jitter of roughly &lt;code&gt;$\pm 1$&lt;/code&gt; can push initial coefficient values far from reasonable territory. Standardising covariates beforehand, as we did, avoids this.&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  When to use something else
&lt;/h3&gt;

&lt;p&gt;The OIB model assumes that exactly-one observations arise from a fundamentally different process than partial observations. If instead you have data with a spike at zero (e.g., insurance claims where most customers file nothing), you want a &lt;strong&gt;zero-inflated&lt;/strong&gt; model. If you have spikes at both boundaries, you need a &lt;strong&gt;zero-and-one-inflated Beta&lt;/strong&gt; (ZOIB).&lt;/p&gt;

&lt;p&gt;For data with no boundary spikes at all, standard &lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;Beta regression&lt;/a&gt; (via MLE or Bayesian inference) is simpler and sufficient. The extra complexity of the OIB mixture is only justified when the data genuinely contains a discrete mass at the boundary.&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;p&gt;The OIB model sits at the intersection of two lines of research: Beta regression for bounded continuous data, and inflated distributions for boundary spikes.&lt;/p&gt;

&lt;h3&gt;
  
  
  Beta regression: Ferrari and Cribari-Neto (2004)
&lt;/h3&gt;

&lt;p&gt;The foundation is the Beta regression model introduced by Silvia Ferrari and Francisco Cribari-Neto in their 2004 paper "Beta Regression for Modelling Rates and Proportions" (Journal of Applied Statistics, 31(7), 799-815). They observed that rates, proportions, and fractions appear everywhere in applied statistics, yet researchers typically transform them (logit, arcsine) and apply linear regression. This is problematic because the transformation distorts the error structure and complicates interpretation.&lt;/p&gt;

&lt;p&gt;Their key insight was to model the response directly as Beta-distributed, using the mean-precision parameterisation we adopted:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Df%28y%253B%2520%255Cmu%252C%2520%255Cphi%29%2520%253D%2520%255Cfrac%257B%255CGamma%28%255Cphi%29%257D%257B%255CGamma%28%255Cmu%255Cphi%29%255C%252C%255CGamma%28%281-%255Cmu%29%255Cphi%29%257D%255C%252C%2520y%255E%257B%255Cmu%255Cphi%2520-%25201%257D%281-y%29%255E%257B%281-%255Cmu%29%255Cphi%2520-%25201%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Df%28y%253B%2520%255Cmu%252C%2520%255Cphi%29%2520%253D%2520%255Cfrac%257B%255CGamma%28%255Cphi%29%257D%257B%255CGamma%28%255Cmu%255Cphi%29%255C%252C%255CGamma%28%281-%255Cmu%29%255Cphi%29%257D%255C%252C%2520y%255E%257B%255Cmu%255Cphi%2520-%25201%257D%281-y%29%255E%257B%281-%255Cmu%29%255Cphi%2520-%25201%257D" alt="equation" width="542" height="59"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$0 &amp;lt; y &amp;lt; 1$&lt;/code&gt;, &lt;code&gt;$0 &amp;lt; \mu &amp;lt; 1$&lt;/code&gt; is the mean, and &lt;code&gt;$\phi &amp;gt; 0$&lt;/code&gt; is the precision. The Beta family belongs to the exponential family, and Ferrari and Cribari-Neto connected &lt;code&gt;$\mu$&lt;/code&gt; to covariates through a link function, with the logit as the standard choice:&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"The proposed model is useful for situations where the variable of interest is continuous and restricted to the interval (0, 1). [...] A convenient parameterisation of the beta density in terms of the mean and a precision parameter is used."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Their framework supports maximum likelihood estimation, but the Bayesian extension (which we use) adds uncertainty quantification and regularisation through priors. The connection to &lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;MLE&lt;/a&gt; is direct: under flat priors, the posterior mode of the Beta component of our model (fitted to the partial repayments) coincides with the MLE of Ferrari and Cribari-Neto's model.&lt;/p&gt;

&lt;h3&gt;
  
  
  Inflated models: Ospina and Ferrari (2010)
&lt;/h3&gt;

&lt;p&gt;The standard Beta has support on the open interval (0, 1), so it cannot assign positive probability to the boundaries 0 or 1. Raydonal Ospina and Silvia Ferrari addressed this in "Inflated Beta Distributions" (Statistical Papers, 51(1), 111-126, 2010). They defined a class of mixed continuous-discrete distributions:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BOIB%257D%28y%2520%255Cmid%2520%255Cpi%252C%2520%255Cmu%252C%2520%255Cphi%29%2520%253D%2520%255Cbegin%257Bcases%257D%2520%255Cpi%2520%2526%2520%255Ctext%257Bif%2520%257D%2520y%2520%253D%25201%2520%255C%255C%2520%281%2520-%2520%255Cpi%29%2520%255Ccdot%2520f_%257B%255Ctext%257BBeta%257D%257D%28y%2520%255Cmid%2520%255Cmu%252C%2520%255Cphi%29%2520%2526%2520%255Ctext%257Bif%2520%257D%25200%2520%253C%2520y%2520%253C%25201%2520%255Cend%257Bcases%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BOIB%257D%28y%2520%255Cmid%2520%255Cpi%252C%2520%255Cmu%252C%2520%255Cphi%29%2520%253D%2520%255Cbegin%257Bcases%257D%2520%255Cpi%2520%2526%2520%255Ctext%257Bif%2520%257D%2520y%2520%253D%25201%2520%255C%255C%2520%281%2520-%2520%255Cpi%29%2520%255Ccdot%2520f_%257B%255Ctext%257BBeta%257D%257D%28y%2520%255Cmid%2520%255Cmu%252C%2520%255Cphi%29%2520%2526%2520%255Ctext%257Bif%2520%257D%25200%2520%253C%2520y%2520%253C%25201%2520%255Cend%257Bcases%257D" alt="equation" width="599" height="75"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This is exactly the piecewise density we implemented with &lt;code&gt;pm.Potential&lt;/code&gt;. The parameter &lt;code&gt;$\pi$&lt;/code&gt; controls the inflation: the probability of observing the boundary value. Ospina and Ferrari also developed zero-inflated and zero-and-one-inflated variants for different boundary patterns.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"In many practical situations, the variable of interest is continuous in the open standard unit interval but may also assume the extreme values zero and/or one with positive probabilities. [...] We introduce a class of inflated beta distributions."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Their work established the theoretical properties (moments, maximum likelihood estimation, score functions) that underpin our Bayesian implementation.&lt;/p&gt;
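&lt;p&gt;One of those moments is easy to check numerically: the mixture structure implies a mean of &lt;code&gt;$\pi + (1 - \pi)\mu$&lt;/code&gt;. The sketch below simulates from a one-inflated Beta and compares the sample mean with that formula. The parameter values and the helper name &lt;code&gt;sample_oib&lt;/code&gt; are assumptions made for this illustration, not taken from the paper.&lt;/p&gt;

```python
import numpy as np

def sample_oib(pi_, mu, phi, size, rng):
    """Draw from a one-inflated Beta: 1 with probability pi_, else Beta(mu*phi, (1-mu)*phi)."""
    is_one = rng.binomial(1, pi_, size).astype(bool)
    partial = rng.beta(mu * phi, (1.0 - mu) * phi, size)
    return np.where(is_one, 1.0, partial)

rng = np.random.default_rng(0)
pi_, mu, phi = 0.3, 0.6, 8.0   # illustrative values only
y = sample_oib(pi_, mu, phi, 200_000, rng)

theoretical_mean = pi_ + (1.0 - pi_) * mu
print(f"theoretical mean:    {theoretical_mean:.3f}")
print(f"simulated mean:      {y.mean():.3f}")
print(f"share of exact ones: {(y == 1.0).mean():.3f}")
```

&lt;p&gt;The share of exact ones should land close to &lt;code&gt;pi_&lt;/code&gt;, which is the inflation parameter doing its job.&lt;/p&gt;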

&lt;h3&gt;
  
  
  From MLE to MCMC
&lt;/h3&gt;

&lt;p&gt;The original MLE approach estimates &lt;code&gt;$\pi$&lt;/code&gt;, &lt;code&gt;$\mu$&lt;/code&gt;, and &lt;code&gt;$\phi$&lt;/code&gt; by maximising the log-likelihood. The Bayesian version replaces optimisation with &lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC sampling&lt;/a&gt;, yielding full posterior distributions rather than point estimates. This is particularly valuable for the OIB model because the piecewise likelihood creates a posterior geometry that point estimates cannot capture: the uncertainty in &lt;code&gt;$\pi$&lt;/code&gt; and &lt;code&gt;$\mu$&lt;/code&gt; is correlated, and the posterior for &lt;code&gt;$\phi$&lt;/code&gt; is often skewed.&lt;/p&gt;

&lt;p&gt;Where Ferrari and Cribari-Neto derived score functions by hand, we supply the log-density components to PyMC and let the NUTS sampler handle the rest. The automatic differentiation in PyTensor computes gradients through the gammaln and log operations, enabling efficient Hamiltonian Monte Carlo.&lt;/p&gt;
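&lt;p&gt;For contrast, here is a minimal sketch of the point-estimate route for an intercept-only OIB model, using &lt;code&gt;scipy.optimize.minimize&lt;/code&gt; on simulated data. The data, the true parameter values and the function name &lt;code&gt;neg_log_lik&lt;/code&gt; are illustrative assumptions. The optimiser returns one parameter vector and nothing about how uncertain or skewed each estimate is.&lt;/p&gt;

```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import expit, gammaln

rng = np.random.default_rng(1)
n = 5000
true_pi, true_mu, true_phi = 0.3, 0.6, 8.0   # illustrative values only
is_one = rng.binomial(1, true_pi, n).astype(bool)
y = np.where(is_one, 1.0, rng.beta(true_mu * true_phi, (1 - true_mu) * true_phi, n))

def neg_log_lik(params, y):
    """Negative OIB log-likelihood, parameters on an unconstrained scale."""
    pi_ = expit(params[0])      # logit scale to (0, 1)
    mu = expit(params[1])       # logit scale to (0, 1)
    phi = np.exp(params[2])     # log scale to positive reals
    ones = y == 1.0
    y_part = y[~ones]
    a, b = mu * phi, (1.0 - mu) * phi
    # Beta log-density written with gammaln, using a + b = phi
    log_beta = (gammaln(phi) - gammaln(a) - gammaln(b)
                + (a - 1.0) * np.log(y_part) + (b - 1.0) * np.log1p(-y_part))
    ll = ones.sum() * np.log(pi_) + np.sum(np.log1p(-pi_) + log_beta)
    return -ll

result = minimize(neg_log_lik, x0=np.array([0.0, 0.0, 1.0]), args=(y,),
                  method="Nelder-Mead")
pi_hat, mu_hat = expit(result.x[0]), expit(result.x[1])
phi_hat = np.exp(result.x[2])
print(f"pi_hat={pi_hat:.3f}, mu_hat={mu_hat:.3f}, phi_hat={phi_hat:.2f}")
```

&lt;p&gt;Nelder-Mead is used here only because it needs no gradients; it is a choice made for the sketch, not the method of the original papers.&lt;/p&gt;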

&lt;h3&gt;
  
  
  Algorithm summary
&lt;/h3&gt;

&lt;p&gt;The complete OIB regression procedure:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;For each observation &lt;code&gt;$i$&lt;/code&gt;, compute &lt;code&gt;$\pi_i = \sigma(\psi_0 + \mathbf{x}_i^\top \boldsymbol{\psi})$&lt;/code&gt; (full repayment probability)&lt;/li&gt;
&lt;li&gt;Compute &lt;code&gt;$\theta_i = \sigma(\delta_0 + \mathbf{x}_i^\top \boldsymbol{\delta})$&lt;/code&gt; (partial repayment mean)&lt;/li&gt;
&lt;li&gt;Compute Beta shape parameters: &lt;code&gt;$\alpha_i = \theta_i \phi$&lt;/code&gt;, &lt;code&gt;$\beta_i = (1 - \theta_i) \phi$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;Evaluate the piecewise log-likelihood: &lt;code&gt;$\log \pi_i$&lt;/code&gt; if &lt;code&gt;$y_i = 1$&lt;/code&gt;, else &lt;code&gt;$\log(1 - \pi_i) + \log \text{Beta}(y_i \mid \alpha_i, \beta_i)$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;Sum across all observations and sample the posterior via NUTS&lt;/li&gt;
&lt;/ol&gt;
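&lt;p&gt;The steps above can be sketched in plain NumPy for one fixed set of coefficients. The covariates, coefficient values and the function name &lt;code&gt;oib_log_likelihood&lt;/code&gt; are made up for illustration. In the actual model PyMC evaluates the same quantity symbolically and NUTS explores it, which this sketch does not attempt.&lt;/p&gt;

```python
import numpy as np
from scipy.special import expit, gammaln

def oib_log_likelihood(y, X, psi0, psi, delta0, delta, phi):
    """Total one-inflated Beta log-likelihood for fixed coefficients."""
    pi_ = expit(psi0 + X @ psi)               # step 1: full repayment probability
    theta = expit(delta0 + X @ delta)         # step 2: partial repayment mean
    a, b = theta * phi, (1.0 - theta) * phi   # step 3: Beta shape parameters
    ones = y == 1.0
    # step 4: evaluate each branch only on the observations that belong to it
    ll_ones = np.log(pi_[ones])
    yp, ap, bp = y[~ones], a[~ones], b[~ones]
    log_beta = (gammaln(ap + bp) - gammaln(ap) - gammaln(bp)
                + (ap - 1.0) * np.log(yp) + (bp - 1.0) * np.log1p(-yp))
    ll_part = np.log1p(-pi_[~ones]) + log_beta
    # step 5: sum over observations (sampling itself is left to NUTS)
    return ll_ones.sum() + ll_part.sum()

# Toy data: three loans, two hypothetical standardised covariates
X = np.array([[0.5, -1.0], [0.0, 0.3], [-1.2, 0.8]])
y = np.array([1.0, 0.4, 0.85])
total = oib_log_likelihood(y, X, psi0=-0.5, psi=np.array([0.8, -0.2]),
                           delta0=0.3, delta=np.array([0.4, 0.1]), phi=6.0)
print(total)
```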

&lt;h3&gt;
  
  
  Further reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;The Beta regression paper:&lt;/strong&gt; Ferrari, S. and Cribari-Neto, F. (2004). "Beta Regression for Modelling Rates and Proportions." &lt;em&gt;Journal of Applied Statistics&lt;/em&gt;, 31(7), 799-815.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Inflated distributions:&lt;/strong&gt; Ospina, R. and Ferrari, S. (2010). "Inflated Beta Distributions." &lt;em&gt;Statistical Papers&lt;/em&gt;, 51(1), 111-126.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The PyMC CustomDist guide:&lt;/strong&gt; &lt;a href="https://www.pymc.io/projects/docs/en/latest/api/distributions/custom.html" rel="noopener noreferrer"&gt;PyMC documentation on custom distributions&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Previous in this series:&lt;/strong&gt; Start with &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;Hierarchical Bayesian Regression&lt;/a&gt;, then &lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;Bayesian Survival Analysis&lt;/a&gt; for the progression from built-in distributions to custom likelihoods.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/distribution-explorer" rel="noopener noreferrer"&gt;Distribution Explorer&lt;/a&gt; — Visualise the Beta distribution and other families used in this model&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/bayes-theorem-calculator" rel="noopener noreferrer"&gt;Bayes' Theorem Calculator&lt;/a&gt; — Explore Bayesian reasoning interactively&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;Hierarchical Bayesian Regression with PyMC: When Groups Share Strength&lt;/a&gt; — Partial pooling and group-level priors in PyMC&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/bayesian-survival-analysis-pymc" rel="noopener noreferrer"&gt;Bayesian Survival Analysis with PyMC: Modelling Customer Churn&lt;/a&gt; — Another custom likelihood built in PyMC&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt; — Why we use priors and posteriors instead of point estimates&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Metropolis-Hastings: An Island-Hopping Guide&lt;/a&gt; — The sampling engine behind PyMC&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  When should I use a One-Inflated Beta model instead of logistic regression?
&lt;/h3&gt;

&lt;p&gt;Use OIB when your outcome is a fraction between 0 and 1 with a spike at the boundary value of 1. Logistic regression discards the partial repayment information by collapsing everything into a binary label. OIB preserves both the probability of full repayment and the distribution of partial repayments, giving you richer predictions.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why use pm.Potential instead of pm.CustomDist for the likelihood?
&lt;/h3&gt;

&lt;p&gt;The pm.CustomDist approach evaluates both branches of the piecewise likelihood for every observation during automatic differentiation. When the Beta log-density is evaluated at the boundary value of 1.0, it returns negative infinity, which corrupts the NUTS sampler gradients and causes 100% divergences. Splitting observations with pm.Potential avoids evaluating the Beta log-density at the boundary entirely.&lt;/p&gt;
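&lt;p&gt;The boundary problem can be reproduced outside PyMC. This NumPy/SciPy sketch uses illustrative values and only mimics the "compute both branches, then select" pattern; it does not call &lt;code&gt;pm.CustomDist&lt;/code&gt;. The Beta log-density at 1.0 is negative infinity, and although &lt;code&gt;np.where&lt;/code&gt; picks the other branch for the final value, the infinity has still been computed. In a gradient computation, a discarded infinity like this can turn into NaN. Masking the data first means the Beta term never sees the boundary.&lt;/p&gt;

```python
import numpy as np
from scipy.stats import beta

y = np.array([1.0, 0.4, 1.0, 0.85])
pi_, a, b = 0.3, 3.0, 2.0   # illustrative values only

# Pattern 1: compute both branches for every observation, then select
beta_branch = np.log1p(-pi_) + beta.logpdf(y, a, b)
print(beta_branch)                      # entries where y == 1 are -inf
both = np.where(y == 1.0, np.log(pi_), beta_branch)

# Pattern 2: split first, so the Beta term only sees interior values
ones = y == 1.0
split_total = (ones.sum() * np.log(pi_)
               + np.sum(np.log1p(-pi_) + beta.logpdf(y[~ones], a, b)))

print(np.isfinite(beta_branch).all())         # False
print(np.isclose(both.sum(), split_total))    # True: same value, no -inf touched in pattern 2
```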

&lt;h3&gt;
  
  
  What is the mean-precision parameterisation of the Beta distribution?
&lt;/h3&gt;

&lt;p&gt;Instead of the standard shape parameters alpha and beta, the mean-precision form uses mu (the mean, between 0 and 1) and phi (the precision, controlling spread). This is more interpretable: mu directly tells you the expected partial repayment fraction, while phi tells you how concentrated the distribution is around that mean. The standard parameters are recovered as alpha = mu * phi and beta = (1 - mu) * phi.&lt;/p&gt;
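&lt;p&gt;A short sketch of the conversion using &lt;code&gt;scipy.stats.beta&lt;/code&gt;. The helper name &lt;code&gt;beta_from_mean_precision&lt;/code&gt; and the values of mu and phi are assumptions for this example. It shows that the mean stays at mu while the variance, mu * (1 - mu) / (1 + phi), shrinks as phi grows.&lt;/p&gt;

```python
import numpy as np
from scipy.stats import beta

def beta_from_mean_precision(mu, phi):
    """Frozen SciPy Beta with alpha = mu * phi and beta = (1 - mu) * phi."""
    return beta(mu * phi, (1.0 - mu) * phi)

mu = 0.7
summary = []
for phi in (2.0, 10.0, 50.0):
    dist = beta_from_mean_precision(mu, phi)
    expected_var = mu * (1.0 - mu) / (1.0 + phi)
    summary.append((phi, dist.mean(), dist.var(), expected_var))
    print(f"phi={phi:5.1f}  mean={dist.mean():.3f}  var={dist.var():.5f}  formula={expected_var:.5f}")
```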

&lt;h3&gt;
  
  
  How do I check whether the OIB model fits my data well?
&lt;/h3&gt;

&lt;p&gt;Generate posterior predictive samples by drawing from the fitted model, then compare the resulting distribution to the observed data. The key check is whether the model reproduces both the spike at 1.0 (the proportion of fully repaid loans) and the shape of the continuous partial repayment distribution. If either component is mismatched, the model needs adjustment.&lt;/p&gt;
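&lt;p&gt;The two checks can be sketched as below. Both the "observed" data and the replicated datasets are simulated stand-ins, and the names &lt;code&gt;simulate_oib&lt;/code&gt;, &lt;code&gt;y_obs&lt;/code&gt; and &lt;code&gt;y_rep&lt;/code&gt; are hypothetical; in practice &lt;code&gt;y_rep&lt;/code&gt; would hold the posterior predictive draws from the fitted model, one row per draw.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(2)

def simulate_oib(pi_, mu, phi, size, rng):
    """Stand-in generator: 1 with probability pi_, else a Beta draw."""
    is_one = rng.binomial(1, pi_, size).astype(bool)
    return np.where(is_one, 1.0, rng.beta(mu * phi, (1.0 - mu) * phi, size))

# Stand-ins for observed data and 500 replicated datasets of the same size
y_obs = simulate_oib(0.3, 0.6, 8.0, 1000, rng)
y_rep = simulate_oib(0.3, 0.6, 8.0, (500, 1000), rng)

# Check 1: the spike at 1.0
obs_spike = (y_obs == 1.0).mean()
rep_spike = (y_rep == 1.0).mean(axis=1)
spike_lo, spike_hi = np.percentile(rep_spike, [2.5, 97.5])
print(f"observed share of ones: {obs_spike:.3f}")
print(f"predictive 95% interval: [{spike_lo:.3f}, {spike_hi:.3f}]")

# Check 2: the shape of the continuous part, summarised by a few quantiles
probs = [0.1, 0.5, 0.9]
obs_q = np.quantile(y_obs[y_obs != 1.0], probs)
rep_q = np.quantile(y_rep[y_rep != 1.0], probs)
print("observed quantiles:  ", np.round(obs_q, 3))
print("predictive quantiles:", np.round(rep_q, 3))
```

&lt;p&gt;An observed share of ones far outside the predictive interval, or quantiles that disagree noticeably, would point to the component that needs work.&lt;/p&gt;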

&lt;h3&gt;
  
  
  Can this model handle spikes at both 0 and 1?
&lt;/h3&gt;

&lt;p&gt;Yes, but you would need a Zero-and-One-Inflated Beta (ZOIB) model. This adds a third mixture component for the spike at zero, with its own probability parameter. The piecewise likelihood gains a third branch, but the pm.Potential implementation pattern remains the same: split observations into three groups and add each group's log-likelihood separately.&lt;/p&gt;
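&lt;p&gt;A minimal NumPy/SciPy sketch of the three-branch log-likelihood, under one possible parameterisation: &lt;code&gt;p0&lt;/code&gt; and &lt;code&gt;p1&lt;/code&gt; are the probabilities of an exact 0 and an exact 1, and the remaining mass goes to the Beta component. That parameterisation, the function name &lt;code&gt;zoib_log_likelihood&lt;/code&gt; and the toy values are assumptions for this example, and there are no covariates.&lt;/p&gt;

```python
import numpy as np
from scipy.stats import beta

def zoib_log_likelihood(y, p0, p1, mu, phi):
    """Three-branch log-likelihood; requires p0 + p1 to stay below 1."""
    zeros = y == 0.0
    ones = y == 1.0
    interior = ~(zeros | ones)
    ll_zero = zeros.sum() * np.log(p0)              # spike at 0
    ll_one = ones.sum() * np.log(p1)                # spike at 1
    ll_beta = np.sum(np.log1p(-(p0 + p1))           # continuous part
                     + beta.logpdf(y[interior], mu * phi, (1.0 - mu) * phi))
    return ll_zero + ll_one + ll_beta

y = np.array([0.0, 1.0, 0.4, 0.85, 1.0, 0.0, 0.62])
zoib_total = zoib_log_likelihood(y, p0=0.2, p1=0.3, mu=0.6, phi=8.0)
print(zoib_total)
```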

</description>
      <category>bayesian</category>
      <category>probabilistic</category>
      <category>pymc</category>
      <category>customlikelihood</category>
    </item>
    <item>
      <title>Bayesian Survival Analysis with PyMC: Modelling Customer Churn</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Wed, 29 Apr 2026 09:53:05 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/bayesian-survival-analysis-with-pymc-modelling-customer-churn-55n4</link>
      <guid>https://dev.to/berkan_sesen/bayesian-survival-analysis-with-pymc-modelling-customer-churn-55n4</guid>
      <description>&lt;p&gt;Every subscription business lives or dies by churn. Whether it is a B2B SaaS platform tracking annual contracts or a consumer app watching monthly renewals, the question is the same: how long will this customer stay? The data seems straightforward. Some subscribers cancelled after a month, others after a year. But a large share of customers are still active. They have not churned yet, and you do not know when, or whether, they will.&lt;/p&gt;

&lt;p&gt;A colleague suggested dropping them from the analysis. That felt wrong, and it is: ignoring active customers biases your model toward shorter lifetimes, because you only learn from the people who already left.&lt;/p&gt;

&lt;p&gt;The problem has a name: &lt;strong&gt;right-censoring&lt;/strong&gt;. An active customer who signed up 8 months ago tells you something valuable: they survived &lt;em&gt;at least&lt;/em&gt; 8 months. You don't know when (or whether) they'll churn, but that lower bound is real information.&lt;/p&gt;
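&lt;p&gt;A small simulation makes the bias visible. The sketch assumes exponentially distributed lifetimes with a 20-month mean and a 24-month observation window, both chosen purely for illustration. Averaging only the churned customers underestimates the true mean; so does treating censored durations as if they were churn times. The censoring-aware estimate for the exponential case (total observed time divided by the number of churn events) recovers it.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(3)
true_mean = 20.0     # months, illustrative
window = 24.0        # observation window in months, illustrative
lifetimes = rng.exponential(true_mean, 100_000)

observed = np.minimum(lifetimes, window)
churned = np.less_equal(lifetimes, window)   # False means still active (censored)

mean_churned_only = observed[churned].mean()   # drops active customers entirely
mean_naive = observed.mean()                   # treats censoring time as churn time
# Exponential MLE with right-censoring: total exposure / number of events
mean_censoring_aware = observed.sum() / churned.sum()

print(f"true mean:            {true_mean:.2f}")
print(f"churned only:         {mean_churned_only:.2f}")
print(f"naive (all observed): {mean_naive:.2f}")
print(f"censoring-aware:      {mean_censoring_aware:.2f}")
```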

&lt;p&gt;Survival analysis handles censoring properly. In our &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;previous post&lt;/a&gt;, we built hierarchical models in PyMC for grouped regression. This post extends that toolkit with a new ingredient: the ability to learn from incomplete observations.&lt;/p&gt;

&lt;p&gt;By the end, you'll build a Bayesian accelerated failure time (AFT) model in PyMC, handle right-censored data with &lt;code&gt;pm.Potential&lt;/code&gt;, compare Weibull and Log-Logistic distributions, and plot individual survival curves for different customer profiles.&lt;/p&gt;

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;First, let's see the model in action. Click the badge below to open the full interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/bayesian/bayesian_survival_analysis.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We'll generate synthetic churn data for 1,000 customers, fit a Weibull AFT model, and plot survival curves.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fx3rnp8lwbc8rg52yidut.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fx3rnp8lwbc8rg52yidut.gif" alt="Survival curves for three customer profiles building up as MCMC samples accumulate. Early frames show scattered, uncertain curves; later frames converge to smooth, separated survival functions for high-value, average, and at-risk customers." width="800" height="500"&gt;&lt;/a&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pandas&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pymc&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pytensor.tensor&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;arviz&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;

&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Generate synthetic churn data: 1,000 customers observed over 24 months
&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1000&lt;/span&gt;
&lt;span class="n"&gt;monthly_spend&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;30&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;clip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;250&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;support_tickets&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;poisson&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;astype&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;float&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Standardise covariates
&lt;/span&gt;&lt;span class="n"&gt;spend_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;monthly_spend&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mi"&gt;30&lt;/span&gt;
&lt;span class="n"&gt;tickets_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;support_tickets&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;

&lt;span class="c1"&gt;# True AFT parameters (Gumbel / log-Weibull parameterisation)
&lt;/span&gt;&lt;span class="n"&gt;true_alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;2.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;  &lt;span class="c1"&gt;# intercept, spend, tickets
&lt;/span&gt;&lt;span class="n"&gt;true_s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.6&lt;/span&gt;

&lt;span class="c1"&gt;# True log-time: Y = eta + s * W, where W ~ Gumbel(0,1)
&lt;/span&gt;&lt;span class="n"&gt;eta_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;true_alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;true_alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;spend_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;true_alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tickets_std&lt;/span&gt;
&lt;span class="n"&gt;log_time_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eta_true&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;true_s&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;gumbel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;time_true&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_time_true&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Administrative censoring at 24 months
&lt;/span&gt;&lt;span class="n"&gt;observation_window&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;24.0&lt;/span&gt;
&lt;span class="n"&gt;observed_time&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;minimum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;time_true&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;observation_window&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;censored&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;time_true&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;observation_window&lt;/span&gt;  &lt;span class="c1"&gt;# True = still active
&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;observed_time&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Total customers: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Churned: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; (&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;)&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Still active (censored): &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; (&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;)&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Total customers: 1000
Churned: 664 (66%)
Still active (censored): 336 (34%)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Before fitting the Bayesian model, let's look at the empirical survival curve using the Kaplan-Meier estimator. This non-parametric method handles censoring correctly by adjusting the risk set at each event time:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Kaplan-Meier estimator (manual, no extra dependencies)
&lt;/span&gt;&lt;span class="n"&gt;order&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argsort&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;observed_time&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;times_sorted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;order&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;events_sorted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="n"&gt;order&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;astype&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;km_times&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;km_survival&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;n_at_risk&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;event&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;times_sorted&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;events_sorted&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;event&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;km_survival&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;km_survival&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;n_at_risk&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;km_times&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;n_at_risk&lt;/span&gt; &lt;span class="o"&gt;-=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;km_times&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;km_survival&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;where&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;post&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;#2196F3&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lw&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Months since signup&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Survival probability&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Kaplan-Meier Survival Curve&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;25&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.05&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fca6xhp3ao8ssvpctzm0a.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fca6xhp3ao8ssvpctzm0a.webp" alt="Kaplan-Meier survival curve for the synthetic churn data. The curve drops steadily over 24 months with censoring tick marks visible. About 34% of customers survive past the 24-month observation window." width="800" height="451"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Now let's fit the Weibull AFT model. The key insight: if &lt;code&gt;$T \sim \text{Weibull}$&lt;/code&gt;, then &lt;code&gt;$Y = \log T$&lt;/code&gt; follows a Gumbel distribution. So we model log-time with a Gumbel likelihood, which lets us write the linear predictor naturally:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;gumbel_log_sf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Log survival function of the Gumbel distribution.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log1p&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;weibull_aft&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="c1"&gt;# Location coefficients (intercept, spend, tickets) with Normal(0, 2) priors
&lt;/span&gt;    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Scale parameter: sample log_s and exponentiate so s stays positive
&lt;/span&gt;    &lt;span class="n"&gt;log_s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Deterministic&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_s&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="c1"&gt;# Linear predictor for log-time
&lt;/span&gt;    &lt;span class="n"&gt;eta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;spend_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tickets_std&lt;/span&gt;

    &lt;span class="c1"&gt;# Uncensored customers: standard Gumbel likelihood
&lt;/span&gt;    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Gumbel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                       &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

    &lt;span class="c1"&gt;# Censored customers: survival function via pm.Potential
&lt;/span&gt;    &lt;span class="n"&gt;y_cens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_cens&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="nf"&gt;gumbel_log_sf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="c1"&gt;# Sample the posterior
&lt;/span&gt;    &lt;span class="n"&gt;trace&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cores&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                      &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target_accept&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;You just fit a Bayesian survival model that properly handles censored customers. The &lt;code&gt;alpha&lt;/code&gt; coefficients tell you how each covariate affects time-to-churn: positive means longer survival, negative means faster churn. And unlike a point estimate from &lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;maximum likelihood&lt;/a&gt;, you get full posterior distributions over every parameter.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Right-Censoring: Learning from Incomplete Data
&lt;/h3&gt;

&lt;p&gt;The 336 active customers in our data didn't churn during the 24-month observation window. For each one, we know they survived &lt;em&gt;at least&lt;/em&gt; 24 months, but not how much longer they'll stay. This is &lt;strong&gt;right-censoring&lt;/strong&gt;: the true event time is somewhere to the right of what we observed.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fyjqswh6ctfl4cr46c73q.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fyjqswh6ctfl4cr46c73q.webp" alt="Timeline diagram showing 8 example customers. Five lines end with a red X (churn event) at various times. Three lines extend to the 24-month boundary and end with a green arrow (still active, censored). The observation window is shaded." width="800" height="412"&gt;&lt;/a&gt;&lt;/p&gt;

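&lt;p&gt;Standard regression would force you to either drop censored customers (biasing survival-time estimates downward, because the longest-lived customers disappear from the data) or code them as churning at 24 months (also biased downward, because their true churn time is later). Survival analysis instead treats the two types of observation differently in the likelihood.&lt;/p&gt;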

&lt;p&gt;For a churned customer at time &lt;code&gt;$t_i$&lt;/code&gt;, the likelihood contribution is the probability density &lt;code&gt;$f(t_i)$&lt;/code&gt;: we observed this exact event time. For a censored customer observed until time &lt;code&gt;$c_i$&lt;/code&gt;, the contribution is the survival probability &lt;code&gt;$S(c_i) = P(T &amp;gt; c_i)$&lt;/code&gt;: all we know is they lasted at least this long.&lt;/p&gt;

&lt;p&gt;The total log-likelihood combines both pieces:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cell%28%255Ctheta%29%2520%253D%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Buncensored%257D%257D%2520%255Clog%2520f%28t_i%2520%255Cmid%2520%255Ctheta%29%2520%255C%253B%252B%255C%253B%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Bcensored%257D%257D%2520%255Clog%2520S%28c_i%2520%255Cmid%2520%255Ctheta%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cell%28%255Ctheta%29%2520%253D%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Buncensored%257D%257D%2520%255Clog%2520f%28t_i%2520%255Cmid%2520%255Ctheta%29%2520%255C%253B%252B%255C%253B%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Bcensored%257D%257D%2520%255Clog%2520S%28c_i%2520%255Cmid%2520%255Ctheta%29" alt="equation" width="550" height="54"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This is exactly how our PyMC model works. The &lt;code&gt;pm.Gumbel&lt;/code&gt; line handles the first sum (uncensored density). The &lt;code&gt;pm.Potential&lt;/code&gt; line handles the second sum (censored survival).&lt;/p&gt;
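
&lt;p&gt;To make the two sums concrete, here is a minimal NumPy sketch of the same censored log-likelihood outside PyMC. It assumes the Gumbel form that &lt;code&gt;pm.Gumbel&lt;/code&gt; and &lt;code&gt;gumbel_log_sf&lt;/code&gt; use; the helper names, parameter values and toy data are hypothetical, chosen only for illustration.&lt;/p&gt;

```python
import numpy as np

def gumbel_logpdf(y, mu, sigma):
    """Log density of the Gumbel form used by pm.Gumbel."""
    z = (y - mu) / sigma
    return -np.log(sigma) - z - np.exp(-z)

def gumbel_logsf(y, mu, sigma):
    """Log survival function, log(1 - exp(-exp(-z)))."""
    z = (y - mu) / sigma
    return np.log1p(-np.exp(-np.exp(-z)))

def censored_loglik(log_time, censored, mu, sigma):
    """Sum log f over observed churn events and log S over censored rows."""
    event_term = gumbel_logpdf(log_time[~censored], mu, sigma).sum()
    censor_term = gumbel_logsf(log_time[censored], mu, sigma).sum()
    return event_term + censor_term

# Toy data (hypothetical): three churn times, two customers censored at 24 months
times = np.array([3.0, 8.0, 15.0, 24.0, 24.0])
censored = np.array([False, False, False, True, True])

total = censored_loglik(np.log(times), censored, mu=2.5, sigma=0.63)
print(total)
```

&lt;p&gt;PyMC does the same bookkeeping for us. The sketch only shows that a censored row contributes a survival term rather than a density term.&lt;/p&gt;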

&lt;h3&gt;
  
  
  Why Gumbel? The Weibull-Gumbel Connection
&lt;/h3&gt;

&lt;p&gt;The Weibull distribution is the workhorse of survival analysis because it models flexible hazard rates: increasing, decreasing, or constant over time. But working with the Weibull directly is numerically awkward for regression.&lt;/p&gt;

&lt;p&gt;Here's the trick. If &lt;code&gt;$T \sim \text{Weibull}(k, \lambda)$&lt;/code&gt;, then &lt;code&gt;$Y = \log T$&lt;/code&gt; follows a Gumbel distribution:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DY%2520%253D%2520%255Clog%2520T%2520%253D%2520%255Cmu%2520%252B%2520%255Csigma%2520%255Ccdot%2520W%252C%2520%255Cquad%2520W%2520%255Csim%2520%255Ctext%257BGumbel%257D%280%252C%25201%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DY%2520%253D%2520%255Clog%2520T%2520%253D%2520%255Cmu%2520%252B%2520%255Csigma%2520%255Ccdot%2520W%252C%2520%255Cquad%2520W%2520%255Csim%2520%255Ctext%257BGumbel%257D%280%252C%25201%29" alt="equation" width="470" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$\mu = \log \lambda$&lt;/code&gt; is the location and &lt;code&gt;$\sigma = 1/k$&lt;/code&gt; is the scale. This is the &lt;strong&gt;accelerated failure time&lt;/strong&gt; (AFT) formulation: covariates shift &lt;code&gt;$\mu$&lt;/code&gt;, effectively accelerating or decelerating time. We write the linear predictor as:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ceta_i%2520%253D%2520%255Calpha_0%2520%252B%2520%255Calpha_1%2520%255Ccdot%2520%255Ctext%257Bspend%257D_i%2520%252B%2520%255Calpha_2%2520%255Ccdot%2520%255Ctext%257Btickets%257D_i" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ceta_i%2520%253D%2520%255Calpha_0%2520%252B%2520%255Calpha_1%2520%255Ccdot%2520%255Ctext%257Bspend%257D_i%2520%252B%2520%255Calpha_2%2520%255Ccdot%2520%255Ctext%257Btickets%257D_i" alt="equation" width="368" height="23"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;A positive &lt;code&gt;$\alpha_1$&lt;/code&gt; means higher spending shifts log-time to the right (longer survival). A negative &lt;code&gt;$\alpha_2$&lt;/code&gt; means more support tickets shift it left (faster churn). The coefficients have a direct interpretation: a one-unit increase in &lt;code&gt;$x_j$&lt;/code&gt; multiplies the median survival time by &lt;code&gt;$\exp(\alpha_j)$&lt;/code&gt;. Because the model uses the standardised covariates &lt;code&gt;spend_std&lt;/code&gt; and &lt;code&gt;tickets_std&lt;/code&gt;, one unit here means one standard deviation.&lt;/p&gt;
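
&lt;p&gt;The multiplier claim is easy to check numerically. The sketch below assumes the same log-time distribution as the model code and uses illustrative parameter values rather than fitted results; &lt;code&gt;time_quantile&lt;/code&gt; is a hypothetical helper. Shifting the linear predictor by &lt;code&gt;alpha1&lt;/code&gt; rescales every quantile of churn time by the same factor, not just the median.&lt;/p&gt;

```python
import numpy as np

def time_quantile(q, eta, s):
    """Time by which a fraction q has churned, under the model's
    log-time CDF F(y) = exp(-exp(-(y - eta) / s))."""
    w_q = -np.log(-np.log(q))        # quantile of the standard error term
    return np.exp(eta + s * w_q)

# Illustrative values only (assumed, not fitted results)
alpha0, alpha1, s = 2.5, 0.4, 0.63

for q in (0.25, 0.5, 0.75):
    base = time_quantile(q, alpha0, s)
    shifted = time_quantile(q, alpha0 + alpha1, s)   # spend_std one unit higher
    print(q, shifted / base)

# Every ratio printed above should equal this multiplier
print(np.exp(alpha1))
```

&lt;p&gt;The scale &lt;code&gt;s&lt;/code&gt; cancels in the ratio, which is why the coefficients can be read as time multipliers without knowing the shape of the distribution.&lt;/p&gt;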

&lt;h3&gt;
  
  
  &lt;code&gt;pm.Potential&lt;/code&gt;: Telling PyMC About Partial Information
&lt;/h3&gt;

&lt;p&gt;In our &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;hierarchical regression post&lt;/a&gt;, every observation contributed a full likelihood term through &lt;code&gt;pm.Normal(..., observed=y)&lt;/code&gt;. Censored observations are different: they don't have a fully observed outcome. They only contribute through the survival function.&lt;/p&gt;

&lt;p&gt;&lt;code&gt;pm.Potential('name', value)&lt;/code&gt; adds &lt;code&gt;value&lt;/code&gt; directly to the model's joint log-probability, and therefore to the unnormalised log-posterior. For censored data, we pass the log-survival probability:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;y_cens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_cens&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="nf"&gt;gumbel_log_sf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Think of it this way. For a churned customer, we say "we observed them leave at time &lt;code&gt;$t$&lt;/code&gt;" (standard likelihood). For an active customer, we say "all we know is they're still here after &lt;code&gt;$c$&lt;/code&gt; months" (survival function).&lt;/p&gt;

&lt;h3&gt;
  
  
  MCMC Diagnostics
&lt;/h3&gt;

&lt;p&gt;Before trusting the results, verify the sampler converged:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot_trace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsfgg84yj6ev5dueud1cb.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsfgg84yj6ev5dueud1cb.webp" alt="ArviZ trace plots for the Weibull AFT model. Top row: alpha posteriors (intercept near 2.5, spend coefficient near 0.4, tickets coefficient near −0.3) with MCMC traces. Bottom row: scale parameter s centred near 0.63. All four chains mix well with stable, overlapping traces." width="800" height="278"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Check the same three diagnostics we covered in the &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;hierarchical regression post&lt;/a&gt;: chains should look like "hairy caterpillars" (good mixing), R-hat below 1.01 (convergence), and effective sample size above 400 per chain (low autocorrelation).&lt;/p&gt;
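
&lt;p&gt;For intuition about what R-hat measures, here is a simplified NumPy sketch of the split R-hat statistic on synthetic chains. It is an assumption-laden stand-in, not what ArviZ computes: ArviZ reports a rank-normalised version by default, and the &lt;code&gt;split_rhat&lt;/code&gt; helper and the fake chains below are hypothetical.&lt;/p&gt;

```python
import numpy as np

def split_rhat(draws):
    """Plain split R-hat for one scalar parameter.

    draws has shape (n_chains, n_draws).
    """
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain in two so drift within a chain also shows up
    halves = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = halves.shape[1]
    within = halves.var(axis=1, ddof=1).mean()       # average within-chain variance
    between = n * halves.mean(axis=1).var(ddof=1)    # variance of the chain means
    var_plus = (n - 1) / n * within + between / n
    return np.sqrt(var_plus / within)

rng = np.random.default_rng(0)
good = rng.normal(size=(4, 1000))                      # four chains on the same target
bad = good + np.array([0.0, 0.0, 0.0, 3.0])[:, None]   # one chain stuck elsewhere

print(split_rhat(good), split_rhat(bad))
```

&lt;p&gt;Well-mixed chains give a value very close to 1, while the chain that sits in a different region inflates the between-chain variance and pushes R-hat well above the 1.01 threshold.&lt;/p&gt;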

&lt;h3&gt;
  
  
  Survival Curves from the Posterior
&lt;/h3&gt;

&lt;p&gt;The payoff of a Bayesian AFT model is individual survival curves with uncertainty bands. For any customer profile, we compute the survival probability at each time point across all posterior samples:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;t_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;36&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;log_t_grid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t_grid&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Extract posterior samples
&lt;/span&gt;&lt;span class="n"&gt;alpha_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reshape&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;s_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;posterior&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="n"&gt;profiles&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;High-value (spend +1.5σ, tickets −1σ)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;#2196F3&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Average customer&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;                        &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;#FF9800&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;At-risk (spend −1.5σ, tickets +2σ)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;     &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;1.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;  &lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;#F44336&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;span class="p"&gt;}&lt;/span&gt;

&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tk&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;profiles&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="n"&gt;eta_post&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha_post&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha_post&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;sp&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha_post&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tk&lt;/span&gt;
    &lt;span class="n"&gt;survival&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eta_post&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t_grid&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eta_post&lt;/span&gt;&lt;span class="p"&gt;)):&lt;/span&gt;
        &lt;span class="c1"&gt;# Standardised log-time, then the same survival form as gumbel_log_sf
&lt;/span&gt;        &lt;span class="n"&gt;z&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_t_grid&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;eta_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;s_post&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;survival&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;z&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;mean_surv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;survival&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# 3rd and 97th percentiles give a 94% equal-tailed credible band
&lt;/span&gt;    &lt;span class="n"&gt;lower&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;percentile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;survival&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;upper&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;percentile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;survival&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;97&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t_grid&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mean_surv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lw&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fill_between&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;t_grid&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;lower&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;upper&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.15&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Months since signup&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Survival probability&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Predicted Survival Curves by Customer Profile&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;upper right&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;fontsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;36&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylim&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.05&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ffghc87rg33vegsxfkjrm.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ffghc87rg33vegsxfkjrm.webp" alt="Survival curves for three customer profiles with 94% HDI bands. The high-value customer (blue) stays above 85% survival at 24 months. The average customer (orange) crosses 50% around 13 months. The at-risk customer (red) drops below 20% by month 10." width="800" height="545"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Each curve shows the model's predicted probability that a customer with those characteristics survives beyond a given time. The high-value customer has a much flatter curve: their predicted median lifetime exceeds 36 months. The at-risk customer (low spend, many support tickets) has a steep drop-off with a median around 5 months.&lt;/p&gt;

&lt;p&gt;Notice the uncertainty bands widen at longer times, especially for the at-risk profile. Fewer customers with those characteristics survive that long, so the model has less data to constrain the prediction.&lt;/p&gt;
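
&lt;p&gt;To make the link between the linear predictor and these curves concrete, here is a minimal sketch that evaluates a Weibull AFT survival function for a single posterior draw. It assumes the standard convention in which log-time equals &lt;code&gt;$\eta$&lt;/code&gt; plus &lt;code&gt;$s$&lt;/code&gt; times a standard minimum-extreme-value error, so that the Weibull scale is &lt;code&gt;$\exp(\eta)$&lt;/code&gt; and the shape is &lt;code&gt;$1/s$&lt;/code&gt;. The function names and numeric values are illustrative and are not taken from the fitted model.&lt;/p&gt;

```python
import math

def weibull_aft_survival(t, eta, s):
    """S(t) = exp(-(t / exp(eta)) ** (1 / s)) under the assumed convention."""
    return math.exp(-((t / math.exp(eta)) ** (1.0 / s)))

def weibull_aft_median(eta, s):
    """Time at which S(t) = 0.5: exp(eta) * (ln 2) ** s."""
    return math.exp(eta) * math.log(2.0) ** s

# Hypothetical values for one posterior draw (not estimates from the post)
eta, s = 2.7, 0.8
median = weibull_aft_median(eta, s)
curve = [weibull_aft_survival(t, eta, s) for t in range(1, 37)]  # months 1..36
print(f"median lifetime: {median:.1f} months")
print(f"S(12) = {curve[11]:.2f}, S(36) = {curve[35]:.2f}")
```

&lt;p&gt;Repeating this calculation for every posterior draw and summarising the resulting curves pointwise is what produces a mean curve with an HDI band.&lt;/p&gt;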

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Covariates in the Scale Too
&lt;/h3&gt;

&lt;p&gt;The model above uses a constant scale parameter &lt;code&gt;$s$&lt;/code&gt; for all customers. The original code I adapted goes further by making the scale covariate-dependent:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Ds_i%2520%253D%2520%255Cexp%255C%21%255Cleft%28%255Crho_0%2520%252B%2520%255Crho_1%2520%255Ccdot%2520x_%257Bi1%257D%2520%252B%2520%255Crho_2%2520%255Ccdot%2520x_%257Bi2%257D%255Cright%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Ds_i%2520%253D%2520%255Cexp%255C%21%255Cleft%28%255Crho_0%2520%252B%2520%255Crho_1%2520%255Ccdot%2520x_%257Bi1%257D%2520%252B%2520%255Crho_2%2520%255Ccdot%2520x_%257Bi2%257D%255Cright%29" alt="equation" width="328" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This means the &lt;em&gt;shape&lt;/em&gt; of the Weibull hazard varies across customers. A customer might have both a longer expected lifetime (larger &lt;code&gt;$\eta$&lt;/code&gt;) and more predictable survival (smaller &lt;code&gt;$s$&lt;/code&gt;). In PyMC:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;weibull_aft_hetero&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="c1"&gt;# Location coefficients
&lt;/span&gt;    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# Scale coefficients (matching original code's rho priors)
&lt;/span&gt;    &lt;span class="n"&gt;rho&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;rho&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;eta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;spend_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tickets_std&lt;/span&gt;
    &lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;rho&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;rho&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;spend_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;rho&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tickets_std&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Gumbel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
                       &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="n"&gt;y_cens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_cens&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="nf"&gt;gumbel_log_sf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;

    &lt;span class="n"&gt;trace_hetero&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cores&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                             &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target_accept&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is faithful to the &lt;code&gt;aft_model_factory_explicit&lt;/code&gt; function in the original code, which uses separate &lt;code&gt;rho_interc&lt;/code&gt;, &lt;code&gt;rho_coeff1&lt;/code&gt;, &lt;code&gt;rho_coeff2&lt;/code&gt; parameters for the Gumbel scale. The &lt;code&gt;exp&lt;/code&gt; link ensures &lt;code&gt;$s_i &amp;gt; 0$&lt;/code&gt; for every customer.&lt;/p&gt;
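
&lt;p&gt;As a rough illustration of what a covariate-dependent scale implies, the sketch below computes &lt;code&gt;$s_i$&lt;/code&gt; for three standardised covariate profiles and converts it to the Weibull shape &lt;code&gt;$1/s_i$&lt;/code&gt;, again assuming the standard Weibull AFT convention. The &lt;code&gt;rho&lt;/code&gt; values and profiles are made up for illustration, and &lt;code&gt;scale_for_customer&lt;/code&gt; is a hypothetical helper.&lt;/p&gt;

```python
import math

def scale_for_customer(rho, spend_std, tickets_std):
    """Exp link: s_i = exp(rho0 + rho1 * spend + rho2 * tickets), always positive."""
    return math.exp(rho[0] + rho[1] * spend_std + rho[2] * tickets_std)

rho = (-0.2, -0.3, 0.4)  # hypothetical coefficients, not posterior estimates
profiles = {"high-value": (1.5, -1.0), "average": (0.0, 0.0), "at-risk": (-1.0, 2.0)}

for name, (spend_std, tickets_std) in profiles.items():
    s_i = scale_for_customer(rho, spend_std, tickets_std)
    shape = 1.0 / s_i  # Weibull shape; above 1 means the hazard rises with tenure
    trend = "increasing" if shape > 1 else "non-increasing"
    print(f"{name}: s = {s_i:.2f}, shape = {shape:.2f}, hazard {trend}")
```

&lt;p&gt;With a constant scale, all three profiles would share one shape and differ only in how quickly time runs.&lt;/p&gt;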

&lt;h3&gt;
  
  
  Weibull vs Log-Logistic: Which Tail Shape?
&lt;/h3&gt;

&lt;p&gt;The Weibull model assumes the hazard rate is &lt;em&gt;monotonic&lt;/em&gt;: always increasing, always decreasing, or constant. But some churn patterns are non-monotonic. Churn risk might be low right after signup, climb to a peak as the first renewal decision approaches, and then fall as the customers who remain become increasingly loyal.&lt;/p&gt;

&lt;p&gt;The &lt;strong&gt;Log-Logistic&lt;/strong&gt; AFT model can capture this rise-then-fall pattern: its hazard is hump-shaped when the scale &lt;code&gt;$s$&lt;/code&gt; is below 1 and monotonically decreasing otherwise. (It cannot produce a U-shaped "bathtub" hazard.) In log-time, the Log-Logistic corresponds to a Logistic distribution, just as the Weibull corresponds to a Gumbel. The swap is straightforward:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;logistic_log_sf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Log survival function of the Logistic distribution.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;pt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softplus&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;loglogistic_aft&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;log_s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Deterministic&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_s&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="n"&gt;eta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;spend_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;tickets_std&lt;/span&gt;

    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Logistic&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                         &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;~&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="n"&gt;y_cens&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Potential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_cens&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="nf"&gt;logistic_log_sf&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_observed_time&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;eta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;censored&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="n"&gt;trace_ll&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cores&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                         &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target_accept&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Feplgs0eoeu46xxyayhj2.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Feplgs0eoeu46xxyayhj2.webp" alt="Two-panel comparison of Weibull (left) and Log-Logistic (right) survival curves for the average customer. The Weibull curve decays smoothly following a stretched exponential. The Log-Logistic curve has a heavier tail, decaying more slowly at longer times." width="800" height="345"&gt;&lt;/a&gt;&lt;/p&gt;
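
&lt;p&gt;The difference in hazard shape can be checked numerically. The following sketch evaluates both hazard functions on a grid, assuming the standard parameterisation in which the time-scale parameter is &lt;code&gt;$\exp(\mu)$&lt;/code&gt; and the shape is &lt;code&gt;$1/s$&lt;/code&gt;. The parameter values are arbitrary. With a shape above 1, the Log-Logistic hazard rises to a peak and then declines, while the Weibull hazard keeps rising.&lt;/p&gt;

```python
import math

def weibull_hazard(t, mu, s):
    """h(t) = (k / lam) * (t / lam) ** (k - 1), with lam = exp(mu), k = 1 / s."""
    lam, k = math.exp(mu), 1.0 / s
    return (k / lam) * (t / lam) ** (k - 1.0)

def loglogistic_hazard(t, mu, s):
    """h(t) = (k / lam) * (t / lam) ** (k - 1) / (1 + (t / lam) ** k)."""
    lam, k = math.exp(mu), 1.0 / s
    return (k / lam) * (t / lam) ** (k - 1.0) / (1.0 + (t / lam) ** k)

mu, s = math.log(12.0), 0.5  # arbitrary illustration values (shape k = 2)
grid = [0.5 * i for i in range(1, 121)]  # 0.5 to 60 months
ll = [loglogistic_hazard(t, mu, s) for t in grid]
wb = [weibull_hazard(t, mu, s) for t in grid]

peak_time = grid[ll.index(max(ll))]
print(f"Log-Logistic hazard peaks near month {peak_time:.1f}, then declines")
print(f"Weibull hazard at months 6 and 48: {wb[11]:.3f}, {wb[95]:.3f}")
```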

&lt;p&gt;Compare the two models using LOO-CV (leave-one-out cross-validation) with ArviZ:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;weibull_loo&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;loo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ll_loo&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;loo&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;trace_ll&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;compare&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Weibull&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Log-Logistic&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;trace_ll&lt;/span&gt;&lt;span class="p"&gt;}))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



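&lt;p&gt;Since our synthetic data was generated from a Weibull distribution, the Weibull model should win. Two caveats apply. First, &lt;code&gt;az.loo&lt;/code&gt; needs pointwise log-likelihood values in the trace, which PyMC 5 stores only on request (for example by passing &lt;code&gt;idata_kwargs={"log_likelihood": True}&lt;/code&gt; to &lt;code&gt;pm.sample&lt;/code&gt;). Second, the censored terms added through &lt;code&gt;pm.Potential&lt;/code&gt; are not part of that stored log-likelihood, so the comparison scores only the uncensored observations. On real data, the comparison often reveals which tail shape better captures your customers' churn dynamics.&lt;/p&gt;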

&lt;h3&gt;
  
  
  The Cox Proportional Hazards Alternative
&lt;/h3&gt;

&lt;p&gt;Survival analysis has a dominant semi-parametric approach: the Cox proportional hazards (PH) model. It doesn't assume a distribution for the baseline hazard, only that covariates multiply the hazard by a constant factor. This flexibility made it ubiquitous in clinical trials.&lt;/p&gt;

&lt;p&gt;So why choose a parametric Bayesian AFT model? Three reasons:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Full predictive distributions.&lt;/strong&gt; The Cox model gives hazard ratios, but producing survival curves requires additional estimation of the baseline hazard. Our Bayesian AFT model gives survival curves with uncertainty bands directly from the posterior.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Small samples and heavy censoring.&lt;/strong&gt; When most customers are still active (and therefore censored), few events contribute to the Cox partial likelihood, so its estimates can be imprecise. Bayesian priors stabilise estimates, especially for rare covariates. This is the same principle of "borrowing strength" we explored in the &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;hierarchical regression post&lt;/a&gt;.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Natural extension.&lt;/strong&gt; PyMC models compose freely. Adding group structure (churn by subscription tier), time-varying covariates, or custom likelihoods is straightforward. The next post in this series demonstrates exactly this with a one-inflated Beta regression.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fu8wt5fam678v77rhkb69.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fu8wt5fam678v77rhkb69.webp" alt="Flow diagram showing the AFT model structure. Covariates (monthly spend, support tickets) feed into two linear predictors: one for the location parameter eta and one for the scale parameter s. These combine into a Gumbel distribution for log-time, which maps to a Weibull distribution for actual survival time. Censored and uncensored paths split at the likelihood." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  When NOT to Use Bayesian AFT
&lt;/h3&gt;

&lt;p&gt;If the proportional hazards assumption holds and your dataset is large (tens of thousands of events), the Cox model is faster to fit and makes fewer distributional assumptions. If you have time-varying covariates that change during a customer's lifetime (e.g., monthly usage patterns), the standard AFT formulation doesn't handle them naturally; you'd need a piecewise approach or a joint model.&lt;/p&gt;

&lt;p&gt;Computational cost matters too. Our 1,000-customer model samples in a few minutes, but production datasets with millions of rows would require approximations like variational inference or mini-batch MCMC.&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Cox (1972): Proportional Hazards
&lt;/h3&gt;

&lt;p&gt;The modern era of survival analysis began with David Cox's 1972 paper "Regression Models and Life-Tables." Cox introduced the proportional hazards model:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dh%28t%2520%255Cmid%2520%255Cmathbf%257Bx%257D%29%2520%253D%2520h_0%28t%29%2520%255Cexp%28%255Cboldsymbol%257B%255Cbeta%257D%255E%255Ctop%2520%255Cmathbf%257Bx%257D%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dh%28t%2520%255Cmid%2520%255Cmathbf%257Bx%257D%29%2520%253D%2520h_0%28t%29%2520%255Cexp%28%255Cboldsymbol%257B%255Cbeta%257D%255E%255Ctop%2520%255Cmathbf%257Bx%257D%29" alt="equation" width="267" height="29"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$h_0(t)$&lt;/code&gt; is an unspecified baseline hazard. The genius was leaving &lt;code&gt;$h_0$&lt;/code&gt; unspecified and estimating &lt;code&gt;$\boldsymbol{\beta}$&lt;/code&gt; through the &lt;strong&gt;partial likelihood&lt;/strong&gt;, which depends only on the order of events, not their exact times. The paper has been cited over 65,000 times, and the Cox model remains one of the most widely used methods in clinical trials.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"The important practical point is that [the partial likelihood] does not require specification of &lt;code&gt;$h_0(t)$&lt;/code&gt;." (Cox, 1972)&lt;/p&gt;
&lt;/blockquote&gt;
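
&lt;p&gt;The order-only property is easy to verify. The sketch below is a minimal, single-covariate implementation of the partial log-likelihood that assumes no tied event times; the toy data and the function name are invented for illustration. Squaring every time preserves the ordering, so the value should be unchanged.&lt;/p&gt;

```python
import math

def cox_partial_loglik(beta, times, events, x):
    """Sum over events of beta * x_i - log(sum of exp(beta * x_j) over the risk set)."""
    total = 0.0
    for i, (t_i, d_i) in enumerate(zip(times, events)):
        if not d_i:
            continue  # censored rows only appear inside risk sets
        # Risk set: everyone still under observation at time t_i
        risk_set = [j for j, t_j in enumerate(times) if t_j >= t_i]
        log_denom = math.log(sum(math.exp(beta * x[j]) for j in risk_set))
        total += beta * x[i] - log_denom
    return total

# Toy data: months observed, event flag (1 = churned, 0 = censored), one covariate
times = [2.0, 3.5, 5.0, 7.0, 9.0]
events = [1, 0, 1, 1, 0]
x = [1.2, -0.4, 0.3, -1.0, 0.8]

original = cox_partial_loglik(0.5, times, events, x)
squared = cox_partial_loglik(0.5, [t ** 2 for t in times], events, x)
print(original, squared)  # the two values should agree
```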

&lt;p&gt;Our AFT model takes a different path: we specify a distribution (Weibull or Log-Logistic), which enables direct time predictions. This parametric assumption is both a strength (more powerful inference when correct) and a weakness (biased inference when wrong).&lt;/p&gt;

&lt;h3&gt;
  
  
  Buckley and James (1979): Accelerated Failure Time
&lt;/h3&gt;

&lt;p&gt;The AFT framework was formalised by Jonathan Buckley and Ian James in 1979. Their key insight was that the AFT model has a direct linear regression interpretation:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Clog%2520T_i%2520%253D%2520%255Cmathbf%257Bx%257D_i%255E%255Ctop%2520%255Cboldsymbol%257B%255Calpha%257D%2520%252B%2520%255Csigma%2520%255Cepsilon_i" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Clog%2520T_i%2520%253D%2520%255Cmathbf%257Bx%257D_i%255E%255Ctop%2520%255Cboldsymbol%257B%255Calpha%257D%2520%252B%2520%255Csigma%2520%255Cepsilon_i" alt="equation" width="199" height="28"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$\epsilon_i$&lt;/code&gt; follows a known distribution (Gumbel for Weibull, Logistic for Log-Logistic) and &lt;code&gt;$\sigma$&lt;/code&gt; is the scale parameter (written &lt;code&gt;$s$&lt;/code&gt; in the models above). The coefficients &lt;code&gt;$\alpha_j$&lt;/code&gt; have a clean meaning: a one-unit increase in &lt;code&gt;$x_j$&lt;/code&gt; multiplies the median survival time, and every other quantile, by &lt;code&gt;$\exp(\alpha_j)$&lt;/code&gt;. This is why it's called "accelerated failure time": covariates speed up or slow down the passage of time.&lt;/p&gt;
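
&lt;p&gt;A quick numerical check of this interpretation follows. It assumes the standardised error is either a minimum-extreme-value variable (Weibull case) or a Logistic variable (Log-Logistic case). The coefficient values and the &lt;code&gt;median_lifetime&lt;/code&gt; helper are hypothetical.&lt;/p&gt;

```python
import math

def median_lifetime(alpha, x, sigma, family):
    """Median of T when log T = alpha . x + sigma * eps."""
    eta = sum(a * v for a, v in zip(alpha, x))
    # Median of the standardised error: log(log 2) for the minimum
    # extreme-value distribution, 0 for the symmetric Logistic.
    eps_median = math.log(math.log(2.0)) if family == "weibull" else 0.0
    return math.exp(eta + sigma * eps_median)

alpha = (2.5, 0.4, -0.3)  # hypothetical intercept and two slopes
sigma = 0.7
base = (1.0, 0.0, 0.0)    # leading 1.0 multiplies the intercept
bumped = (1.0, 1.0, 0.0)  # first covariate raised by one unit

for family in ("weibull", "loglogistic"):
    ratio = median_lifetime(alpha, bumped, sigma, family) / median_lifetime(alpha, base, sigma, family)
    print(family, round(ratio, 4), round(math.exp(alpha[1]), 4))
```

&lt;p&gt;Both families give the same ratio because the error term cancels, leaving only the factor &lt;code&gt;$\exp(\alpha_j)$&lt;/code&gt;.&lt;/p&gt;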

&lt;h3&gt;
  
  
  Wei (1992): AFT as an Alternative
&lt;/h3&gt;

&lt;p&gt;L. J. Wei's 1992 paper "The Accelerated Failure Time Model: A Useful Alternative to the Cox Regression Model in Survival Analysis" made the case for AFT models as a practical complement to Cox PH. Wei argued that AFT models are more robust to omitted covariates and provide more interpretable effect sizes.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"When the acceleration factor is constant over time, the AFT model provides a simple and clinically meaningful summary of the survival experience." (Wei, 1992)&lt;/p&gt;
&lt;/blockquote&gt;

&lt;h3&gt;
  
  
  Handling Censoring in PyMC
&lt;/h3&gt;

&lt;p&gt;The &lt;code&gt;pm.Potential&lt;/code&gt; approach for censored data follows directly from the likelihood factorisation. For a dataset with observed and censored outcomes:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DL%28%255Ctheta%29%2520%253D%2520%255Cprod_%257Bi%2520%255Cin%2520%255Ctext%257Buncensored%257D%257D%2520f%28t_i%2520%255Cmid%2520%255Ctheta%29%2520%255C%253B%255Cprod_%257Bi%2520%255Cin%2520%255Ctext%257Bcensored%257D%257D%2520S%28c_i%2520%255Cmid%2520%255Ctheta%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DL%28%255Ctheta%29%2520%253D%2520%255Cprod_%257Bi%2520%255Cin%2520%255Ctext%257Buncensored%257D%257D%2520f%28t_i%2520%255Cmid%2520%255Ctheta%29%2520%255C%253B%255Cprod_%257Bi%2520%255Cin%2520%255Ctext%257Bcensored%257D%257D%2520S%28c_i%2520%255Cmid%2520%255Ctheta%29" alt="equation" width="452" height="54"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Taking logs, the uncensored terms give the standard log-likelihood (handled by &lt;code&gt;pm.Gumbel&lt;/code&gt; or &lt;code&gt;pm.Logistic&lt;/code&gt;). The censored terms give log-survival values (handled by &lt;code&gt;pm.Potential&lt;/code&gt;). This pattern appears throughout the &lt;a href="https://www.pymc.io/projects/examples/en/latest/survival_analysis/weibull_aft.html" rel="noopener noreferrer"&gt;PyMC survival analysis examples&lt;/a&gt; and extends naturally to other censoring types by swapping the survival function for the appropriate probability: the CDF &lt;code&gt;$F(c_i \mid \theta)$&lt;/code&gt; for left censoring, or the difference &lt;code&gt;$F(b_i \mid \theta) - F(a_i \mid \theta)$&lt;/code&gt; for an event known only to fall between &lt;code&gt;$a_i$&lt;/code&gt; and &lt;code&gt;$b_i$&lt;/code&gt;.&lt;/p&gt;
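
&lt;p&gt;As a sketch of how the different censoring types slot into the likelihood, the function below returns the log-likelihood contribution of a single observation under a Weibull model on the time scale. The &lt;code&gt;kind&lt;/code&gt; labels and the function itself are hypothetical helpers, not part of PyMC, and the parameter values are made up.&lt;/p&gt;

```python
import math

def weibull_log_contribution(kind, k, lam, t=None, lower=None, upper=None):
    """Log-likelihood term for one observation; k is the shape, lam the scale."""
    def surv(u):
        return math.exp(-((u / lam) ** k))

    if kind == "event":      # exact time observed: log f(t)
        return math.log(k / lam) + (k - 1.0) * math.log(t / lam) - (t / lam) ** k
    if kind == "right":      # still active at t: log S(t)
        return -((t / lam) ** k)
    if kind == "left":       # churned some time before t: log F(t)
        return math.log1p(-surv(t))
    if kind == "interval":   # churned between lower and upper: log(S(lower) - S(upper))
        return math.log(surv(lower) - surv(upper))
    raise ValueError(f"unknown censoring kind: {kind}")

# Hypothetical parameters and observations
k, lam = 1.3, 14.0
rows = [("event", {"t": 6.0}), ("right", {"t": 24.0}),
        ("left", {"t": 3.0}), ("interval", {"lower": 6.0, "upper": 12.0})]
total = sum(weibull_log_contribution(kind, k, lam, **kw) for kind, kw in rows)
print(f"total log-likelihood: {total:.3f}")
```

&lt;p&gt;In a PyMC model, each of the three censored branches would be added to the joint log-probability in the same way as the right-censored term.&lt;/p&gt;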

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;The proportional hazards model:&lt;/strong&gt; Cox, D. R. (1972). "Regression Models and Life-Tables." &lt;em&gt;Journal of the Royal Statistical Society: Series B&lt;/em&gt;, 34(2), 187-220.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The AFT framework:&lt;/strong&gt; Buckley, J. &amp;amp; James, I. (1979). "Linear Regression with Censored Data." &lt;em&gt;Biometrika&lt;/em&gt;, 66(3), 429-436.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;AFT as a Cox alternative:&lt;/strong&gt; Wei, L. J. (1992). "The Accelerated Failure Time Model: A Useful Alternative to the Cox Regression Model in Survival Analysis." &lt;em&gt;Statistics in Medicine&lt;/em&gt;, 11(14-15), 1871-1879.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The standard reference:&lt;/strong&gt; Kalbfleisch, J. D. &amp;amp; Prentice, R. L. (2002). &lt;em&gt;The Statistical Analysis of Failure Time Data&lt;/em&gt;, 2nd ed. Wiley.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;PyMC survival example:&lt;/strong&gt; &lt;a href="https://www.pymc.io/projects/examples/en/latest/survival_analysis/weibull_aft.html" rel="noopener noreferrer"&gt;Weibull AFT notebook&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Previous in this series:&lt;/strong&gt; &lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;Hierarchical Bayesian Regression with PyMC&lt;/a&gt;, which introduces PyMC, partial pooling, and ArviZ diagnostics.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Next in this series:&lt;/strong&gt; Custom likelihoods in PyMC, where we build a one-inflated Beta regression for bounded outcome data.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/kaplan-meier-calculator" rel="noopener noreferrer"&gt;Kaplan-Meier Calculator&lt;/a&gt;: Estimate survival curves and compare groups interactively.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/medical-stats-calculator" rel="noopener noreferrer"&gt;Medical Statistics Calculator&lt;/a&gt;: Compute sensitivity, specificity, and other diagnostic metrics.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/hierarchical-bayesian-regression-pymc" rel="noopener noreferrer"&gt;Hierarchical Bayesian Regression with PyMC&lt;/a&gt;: The first post in this PyMC series, covering partial pooling and MCMC diagnostics with ArviZ.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Island Hopping: Understanding Metropolis-Hastings&lt;/a&gt;: The Metropolis-Hastings foundations behind MCMC samplers such as NUTS, which PyMC uses to explore high-dimensional posteriors.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt;: The conceptual foundation for priors, posteriors, and what a full posterior distribution offers beyond a single point estimate.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is right-censoring and why does it matter?
&lt;/h3&gt;

&lt;p&gt;Right-censoring occurs when you know a subject survived at least until a certain time, but not the actual event time. In churn analysis, active customers are right-censored because they have not yet churned. Ignoring them biases your model toward shorter lifetimes, since you only learn from customers who already left. Survival analysis handles censoring properly by using the survival function for these partial observations.&lt;/p&gt;
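
&lt;p&gt;A quick simulation shows the size of that bias. This is a sketch under simple assumptions (exponential lifetimes, a single fixed observation window, made-up numbers), comparing a naive average of churned customers with the censoring-aware exponential estimate, which divides total observed time by the number of events.&lt;/p&gt;

```python
import random

random.seed(0)

TRUE_MEAN = 10.0   # assumed true mean lifetime in months
WINDOW = 8.0       # assumed observation window; longer lifetimes are censored
N = 5000

lifetimes = [random.expovariate(1.0 / TRUE_MEAN) for _ in range(N)]
observed_time = [min(t, WINDOW) for t in lifetimes]
churned = [WINDOW > t for t in lifetimes]

# Naive estimate: average only the customers who already churned
churn_times = [t for t, e in zip(observed_time, churned) if e]
naive_mean = sum(churn_times) / len(churn_times)

# Censoring-aware exponential MLE: total exposure time / number of events
aware_mean = sum(observed_time) / sum(churned)

print(f"naive: {naive_mean:.2f}, censoring-aware: {aware_mean:.2f}")
```

&lt;p&gt;The naive average can never exceed the observation window, while the censoring-aware estimate uses the active customers' exposure time and lands close to the true mean.&lt;/p&gt;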

&lt;h3&gt;
  
  
  What is the difference between the Cox model and an AFT model?
&lt;/h3&gt;

&lt;p&gt;The Cox proportional hazards model is semi-parametric: it leaves the baseline hazard unspecified and estimates how covariates multiply the hazard rate. The accelerated failure time (AFT) model is fully parametric: it assumes a specific distribution (such as Weibull) and models how covariates accelerate or decelerate time to event. AFT coefficients act on log time, so exponentiating one gives a time ratio: the factor by which a one-unit change in the covariate multiplies survival time, including the median. Exponentiated Cox coefficients, by contrast, are hazard ratios.&lt;/p&gt;
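
&lt;p&gt;The time-ratio reading is easy to check numerically. The sketch below assumes a Weibull AFT model with log(scale) linear in a single covariate; the coefficient values are hypothetical and chosen only for illustration.&lt;/p&gt;

```python
import math


def weibull_aft_median(x, intercept, coef, shape):
    """Median survival time when log(scale) = intercept + coef * x."""
    scale = math.exp(intercept + coef * x)
    return scale * math.log(2) ** (1.0 / shape)


# Hypothetical values for illustration only
intercept, coef, shape = 2.0, 0.4, 1.5

median_x0 = weibull_aft_median(0, intercept, coef, shape)
median_x1 = weibull_aft_median(1, intercept, coef, shape)

# Ratio of medians for a one-unit change in x: equals exp(coef)
time_ratio = median_x1 / median_x0

# The Weibull is also a proportional hazards model, with
# hazard ratio exp(-shape * coef) for the same covariate
hazard_ratio = math.exp(-shape * coef)

print(round(time_ratio, 3), round(math.exp(coef), 3), round(hazard_ratio, 3))
```

&lt;p&gt;A positive AFT coefficient stretches survival time and therefore corresponds to a hazard ratio below one, which is why the two models' coefficients usually carry opposite signs.&lt;/p&gt;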

&lt;h3&gt;
  
  
  What does pm.Potential do in PyMC?
&lt;/h3&gt;

&lt;p&gt;pm.Potential adds an arbitrary log-probability term directly to the model's log-posterior. For censored observations, there is no fully observed outcome to pass to a standard likelihood. Instead, you compute the log-survival probability and add it via pm.Potential, telling PyMC that these customers survived at least this long without specifying when they will actually churn.&lt;/p&gt;

&lt;h3&gt;
  
  
  How do I choose between Weibull and Log-Logistic distributions?
&lt;/h3&gt;

&lt;p&gt;Use Weibull when you expect the hazard rate to be monotonic, either always increasing, always decreasing, or constant over time. Use Log-Logistic when the hazard may be non-monotonic, such as churn risk that climbs to a peak shortly after sign-up and then declines as users settle in. The log-logistic hazard is unimodal: it can rise and then fall, but it cannot fall and then rise again. You can compare the two formally using LOO-CV (leave-one-out cross-validation) in ArviZ.&lt;/p&gt;
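
&lt;p&gt;Evaluating the two hazard functions at a few time points makes the difference in shape visible. This is a small sketch using the textbook hazard formulas; the shape and scale values are arbitrary choices for illustration.&lt;/p&gt;

```python
def weibull_hazard(t, shape, scale):
    """Weibull hazard: monotone in t for any fixed shape."""
    return (shape / scale) * (t / scale) ** (shape - 1)


def loglogistic_hazard(t, shape, scale):
    """Log-logistic hazard: rises then falls when shape exceeds 1."""
    z = (t / scale) ** shape
    return (shape / t) * z / (1 + z)


times = [0.5, 1.0, 2.0, 4.0]
weibull = [weibull_hazard(t, shape=1.5, scale=1.0) for t in times]
loglogistic = [loglogistic_hazard(t, shape=2.0, scale=1.0) for t in times]

for t, hw, hl in zip(times, weibull, loglogistic):
    print(f"t={t:4.1f}  weibull={hw:.3f}  log-logistic={hl:.3f}")
```

&lt;p&gt;With these settings the Weibull hazard keeps increasing, while the log-logistic hazard peaks at t = 1 and then declines.&lt;/p&gt;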

&lt;h3&gt;
  
  
  How many customers do I need for a Bayesian survival model?
&lt;/h3&gt;

&lt;p&gt;Bayesian models can work with surprisingly small datasets because priors regularise the estimates, but a practical minimum is a few hundred observations with at least 50 to 100 uncensored events. With heavy censoring (over 80% still active), the model has less information about event times, so you may need a larger sample or more informative priors to get precise estimates.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can I add time-varying covariates to a Bayesian AFT model?
&lt;/h3&gt;

&lt;p&gt;The standard AFT formulation assumes covariates are fixed at baseline and does not naturally handle features that change during a customer's lifetime, such as monthly usage patterns. For time-varying covariates, you would need a piecewise AFT approach that splits each customer's timeline into intervals, or a joint model that links the longitudinal covariate process with the survival outcome.&lt;/p&gt;
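
&lt;p&gt;The interval-splitting step is mostly bookkeeping. Here is a minimal sketch of it, with a hypothetical helper name and a fixed interval length; each resulting row could then carry the covariate values that applied during that interval.&lt;/p&gt;

```python
def split_timeline(duration, churned, step):
    """Split one customer's follow-up into (start, stop, event) rows.

    Only the final row can carry event=1; earlier rows are censored at `stop`.
    """
    rows = []
    start = 0.0
    while duration > start:
        stop = min(start + step, duration)
        is_last = stop == duration
        rows.append((start, stop, int(churned and is_last)))
        start = stop
    return rows


# Hypothetical customer: churned after 7.5 months, timeline cut into 3-month intervals
rows = split_timeline(duration=7.5, churned=True, step=3.0)
print(rows)
```

&lt;p&gt;A customer who is still active produces the same rows with every event flag set to 0, so censoring is handled interval by interval.&lt;/p&gt;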

</description>
      <category>bayesian</category>
      <category>probabilistic</category>
      <category>survivalanalysis</category>
      <category>pymc</category>
    </item>
    <item>
      <title>Hierarchical Bayesian Regression with PyMC: When Groups Share Strength</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Sun, 26 Apr 2026 12:43:53 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/hierarchical-bayesian-regression-with-pymc-when-groups-share-strength-2hag</link>
      <guid>https://dev.to/berkan_sesen/hierarchical-bayesian-regression-with-pymc-when-groups-share-strength-2hag</guid>
      <description>&lt;p&gt;A multi-line insurer writes auto, home, commercial property, and a dozen other policy types under one roof. Some lines see thousands of claims a year; others might see 50. Every actuary faces the same dilemma: train a separate pricing model for each line and the small ones are pure noise, or pool everything together and pretend a warehouse fire looks like a fender bender. Either way, you lose.&lt;/p&gt;

&lt;p&gt;Hierarchical Bayesian regression offers a third way. Each group gets its own parameters, but those parameters are drawn from a shared population distribution. Groups with plenty of data stay close to their own estimates. Groups with little data get "pulled" toward the population average, borrowing statistical strength from the larger groups. This effect is called &lt;strong&gt;shrinkage&lt;/strong&gt;, and it's one of the most elegant ideas in statistics.&lt;/p&gt;

&lt;p&gt;By the end of this post, you'll build a hierarchical Bayesian regression model in &lt;a href="https://www.pymc.io/" rel="noopener noreferrer"&gt;PyMC&lt;/a&gt;, compare it against pooled and unpooled alternatives, and see shrinkage in action on synthetic insurance data.&lt;/p&gt;

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;First, let's see the hierarchical model in action. Click the badge below to open the full interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/bayesian/hierarchical_bayesian_regression.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We'll generate synthetic insurance claim data for three policy types with deliberately unbalanced sample sizes, then fit a hierarchical model.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fuj18jcvc1oul1myfok00.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fuj18jcvc1oul1myfok00.gif" alt="The Commercial intercept posterior building up as MCMC samples accumulate. Early frames show a jagged histogram; later frames resolve to a smooth distribution centred near the true intercept value of 9.0." width="800" height="450"&gt;&lt;/a&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pandas&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pymc&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;arviz&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;

&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Three policy types: lots of Auto data, moderate Home, very little Commercial
&lt;/span&gt;&lt;span class="n"&gt;groups&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Auto&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;       &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;500&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;7.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;slope&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;0.30&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Home&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;       &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mi"&gt;300&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;8.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;slope&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;0.50&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Commercial&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;  &lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;9.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;slope&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="mf"&gt;0.70&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
&lt;span class="p"&gt;}&lt;/span&gt;

&lt;span class="n"&gt;records&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;groups&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()):&lt;/span&gt;
    &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;  &lt;span class="c1"&gt;# log property value (~$160k median)
&lt;/span&gt;    &lt;span class="n"&gt;noise&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;intercept&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;slope&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;noise&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;j&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]):&lt;/span&gt;
        &lt;span class="n"&gt;records&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;
            &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;policy_type&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;group_idx&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_property_value&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_claim_severity&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;
        &lt;span class="p"&gt;})&lt;/span&gt;

&lt;span class="n"&gt;df&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;DataFrame&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;records&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fvdwzplqpuruf55us0di9.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fvdwzplqpuruf55us0di9.webp" alt="Scatter plot of log claim severity vs log property value, coloured by policy type. Auto (blue, n=500) and Home (orange, n=300) have dense clusters while Commercial (green, n=50) is sparse." width="800" height="493"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Each policy type has a different intercept and slope, but Commercial has just 50 data points. Now let's fit the hierarchical model:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;n_types&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;groups&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;group_idx&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;
&lt;span class="n"&gt;x_centered&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_property_value&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;12&lt;/span&gt;  &lt;span class="c1"&gt;# center the predictor
&lt;/span&gt;
&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;hierarchical_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="c1"&gt;# Hyperpriors: the "population" distribution that groups are drawn from
&lt;/span&gt;    &lt;span class="n"&gt;mu_alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;mu_alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;sigma_alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;HalfNormal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma_alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;mu_beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;mu_beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;sigma_beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;HalfNormal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma_beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Group-level parameters, drawn from the population
&lt;/span&gt;    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mu_alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma_alpha&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mu_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma_beta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Observation noise
&lt;/span&gt;    &lt;span class="n"&gt;sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;HalfNormal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Linear model
&lt;/span&gt;    &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;x_centered&lt;/span&gt;
    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                      &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_claim_severity&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Sample the posterior
&lt;/span&gt;    &lt;span class="n"&gt;hierarchical_trace&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1500&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cores&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;target_accept&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Summarise the results
&lt;/span&gt;&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hierarchical_trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;You just estimated three group-specific regression lines (one per policy type) while letting them share statistical strength through a common population distribution. The Commercial group, despite having only 50 claims, gets a stable estimate because it borrows information from Auto and Home.&lt;/p&gt;
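
&lt;p&gt;To build intuition for how much each group borrows, here is a back-of-the-envelope sketch. It assumes a simplified normal model in which the noise standard deviation and the between-group standard deviation are known, so the partially pooled estimate reduces to a precision-weighted average. The numbers are hypothetical and only loosely echo the example above; the full model estimates these quantities from the data.&lt;/p&gt;

```python
def partial_pooling_mean(group_mean, n, sigma, pop_mean, tau):
    """Precision-weighted average of a group's own mean and the population mean.

    Assumes a normal model with known noise sd `sigma` and known
    between-group sd `tau` (a simplification of the full hierarchical model).
    """
    data_precision = n / sigma ** 2
    prior_precision = 1.0 / tau ** 2
    weight = data_precision / (data_precision + prior_precision)
    return weight * group_mean + (1 - weight) * pop_mean, weight


# Hypothetical numbers: noise sd 0.8, population mean 8.2, between-group sd 0.3
for name, n, group_mean in [("Auto", 500, 7.5), ("Commercial", 50, 9.0)]:
    estimate, weight = partial_pooling_mean(
        group_mean, n, sigma=0.8, pop_mean=8.2, tau=0.3
    )
    print(f"{name:10s} weight on own data={weight:.3f}  estimate={estimate:.3f}")
```

&lt;p&gt;The group with 500 observations puts almost all of its weight on its own data, while the group with 50 observations is pulled noticeably further toward the population mean.&lt;/p&gt;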

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Three Pooling Strategies
&lt;/h3&gt;

&lt;p&gt;To understand why the hierarchical model is special, let's compare it against the two extreme alternatives.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Complete pooling&lt;/strong&gt; ignores group differences entirely. One intercept, one slope for all 850 data points:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pooled_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;HalfNormal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;x_centered&lt;/span&gt;
    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                      &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_claim_severity&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;pooled_trace&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cores&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



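&lt;p&gt;&lt;strong&gt;No pooling&lt;/strong&gt; treats each group's regression line as completely independent. Three separate intercepts, three separate slopes, with no information shared between them (for simplicity, the noise scale &lt;code&gt;sigma&lt;/code&gt; is still common to all groups):&lt;br&gt;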
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;unpooled_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;sigma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;HalfNormal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;x_centered&lt;/span&gt;
    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                      &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;log_claim_severity&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;values&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;unpooled_trace&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sample&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tune&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cores&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chains&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fite20um2l4abv9expb61.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fite20um2l4abv9expb61.webp" alt="Three-panel comparison of regression lines. Left: complete pooling (one line through all data, clearly wrong for Commercial). Centre: no pooling (three independent lines, Commercial line is noisy). Right: partial pooling (three lines, Commercial is slightly pulled toward the others)." width="800" height="316"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The comparison reveals the key insight. Complete pooling gives a single line dominated by Auto and Home (which together make up 94% of the data), systematically underestimating Commercial's higher intercept and steeper slope. No pooling gives each group its own line, but Commercial's estimate is noisy because it only has 50 points. Partial pooling (the hierarchical model) sits between the two: each group gets its own line, but the lines are gently pulled toward the population average. Groups with little data get pulled more.&lt;/p&gt;
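
&lt;p&gt;To see the two extremes without running a sampler, here is a minimal least-squares sketch on synthetic data. It is an analogue of the PyMC models above, not a replacement for them. The group sizes follow the article (the Home size of 300 is inferred from the 850 total); the intercepts, slopes, noise level, covariate range and the &lt;code&gt;fit_line&lt;/code&gt; helper are all assumptions made up for illustration.&lt;/p&gt;

```python
import random


def fit_line(xs, ys):
    """Ordinary least squares for y = a + b * x; returns (a, b)."""
    n = len(xs)
    x_bar = sum(xs) / n
    y_bar = sum(ys) / n
    sxx = sum((x - x_bar) ** 2 for x in xs)
    sxy = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return y_bar - slope * x_bar, slope


# Hypothetical ground truth: (n_points, intercept, slope) per policy type.
# Only the group sizes come from the article; the rest is invented.
TRUTH = {
    "Auto": (500, 8.0, 0.30),
    "Home": (300, 8.2, 0.35),
    "Commercial": (50, 9.5, 0.80),
}

rng = random.Random(42)
data = {}
for name, (n, a, b) in TRUTH.items():
    xs = [rng.uniform(-2.0, 2.0) for _ in range(n)]
    ys = [a + b * x + rng.gauss(0.0, 0.5) for x in xs]
    data[name] = (xs, ys)

# Complete pooling: one line through all 850 points.
all_x = [x for xs, _ in data.values() for x in xs]
all_y = [y for _, ys in data.values() for y in ys]
pooled = fit_line(all_x, all_y)

# No pooling: one independent line per group.
unpooled = {name: fit_line(xs, ys) for name, (xs, ys) in data.items()}

print(f"pooled      a={pooled[0]:.2f} b={pooled[1]:.2f}")
for name, (a, b) in unpooled.items():
    print(f"{name:11s} a={a:.2f} b={b:.2f}")
```

&lt;p&gt;Under these assumptions the pooled line lands close to Auto and Home, while the Commercial fit recovers its own intercept and slope from only 50 points.&lt;/p&gt;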

&lt;h3&gt;
  
  
  How Hyperpriors Create Partial Pooling
&lt;/h3&gt;

&lt;p&gt;The magic ingredient is the &lt;strong&gt;hyperpriors&lt;/strong&gt;: &lt;code&gt;mu_alpha&lt;/code&gt;, &lt;code&gt;sigma_alpha&lt;/code&gt;, &lt;code&gt;mu_beta&lt;/code&gt;, &lt;code&gt;sigma_beta&lt;/code&gt;. These define a "population distribution" from which group-level parameters are drawn.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha_j%2520%255Csim%2520%255Cmathcal%257BN%257D%28%255Cmu_%255Calpha%252C%2520%255Csigma_%255Calpha%255E2%29%2520%255Cquad%2520%255Ctext%257Bfor%2520%257D%2520j%2520%253D%25201%252C%2520%255Cldots%252C%2520J" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha_j%2520%255Csim%2520%255Cmathcal%257BN%257D%28%255Cmu_%255Calpha%252C%2520%255Csigma_%255Calpha%255E2%29%2520%255Cquad%2520%255Ctext%257Bfor%2520%257D%2520j%2520%253D%25201%252C%2520%255Cldots%252C%2520J" alt="equation" width="354" height="29"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Think of &lt;code&gt;$\mu_\alpha$&lt;/code&gt; as the average intercept across all policy types, and &lt;code&gt;$\sigma_\alpha$&lt;/code&gt; as how much the types are allowed to differ. If the data supports large differences, &lt;code&gt;$\sigma_\alpha$&lt;/code&gt; will be large and each group behaves almost independently (like no pooling). If the groups are similar, &lt;code&gt;$\sigma_\alpha$&lt;/code&gt; shrinks and the group estimates collapse toward the population mean (like complete pooling).&lt;/p&gt;

&lt;p&gt;The sampler learns &lt;code&gt;$\sigma_\alpha$&lt;/code&gt; from the data itself. You don't have to choose between pooling and no pooling; the model figures out the right amount of sharing automatically.&lt;/p&gt;
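
&lt;p&gt;A quick way to build intuition for the population distribution is to draw from it directly. The sketch below samples three group intercepts for a small and a large &lt;code&gt;sigma_alpha&lt;/code&gt;. The hyperparameter values and the &lt;code&gt;draw_group_intercepts&lt;/code&gt; helper are hypothetical; in the real model these quantities are inferred, not fixed.&lt;/p&gt;

```python
import random
import statistics


def draw_group_intercepts(mu_alpha, sigma_alpha, n_groups, seed):
    """Draw one intercept per group from Normal(mu_alpha, sigma_alpha)."""
    rng = random.Random(seed)
    return [rng.gauss(mu_alpha, sigma_alpha) for _ in range(n_groups)]


# Hypothetical hyperparameter values, chosen only for illustration.
tight = draw_group_intercepts(mu_alpha=8.0, sigma_alpha=0.05, n_groups=3, seed=0)
loose = draw_group_intercepts(mu_alpha=8.0, sigma_alpha=2.0, n_groups=3, seed=0)

print("small sigma_alpha:", [round(a, 2) for a in tight])
print("large sigma_alpha:", [round(a, 2) for a in loose])
print("spread:", statistics.pstdev(tight), "vs", statistics.pstdev(loose))
```

&lt;p&gt;With a small &lt;code&gt;sigma_alpha&lt;/code&gt; the three intercepts are nearly identical, which is the complete-pooling end of the spectrum; with a large one they scatter freely, which is the no-pooling end.&lt;/p&gt;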

&lt;h3&gt;
  
  
  Shrinkage: The Key Insight
&lt;/h3&gt;

&lt;p&gt;Shrinkage is the defining feature of hierarchical models. Compare each group's raw sample mean (what you'd get from no pooling) to its hierarchical posterior mean:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F6di4mv60vkp1unyx0ahx.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F6di4mv60vkp1unyx0ahx.webp" alt="Shrinkage plot showing raw group means (circles) and hierarchical posterior means (triangles) for each policy type's intercept. The horizontal dashed line marks the population mean. Commercial moves the most toward the population mean; Auto barely moves." width="800" height="472"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Commercial's intercept gets pulled the most toward the population mean, because it has the least data and therefore the most uncertainty. Auto barely moves, because 500 data points leave little room for the prior to override the evidence. This is exactly the Bayesian compromise between prior and data that we explored in &lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt;.&lt;/p&gt;
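
&lt;p&gt;The link between group size and shrinkage can be made concrete with the textbook precision-weighted formula for a Normal model whose variances are treated as known. This is a simplification of what the sampler does, not the sampler itself. The group sizes follow the article; the values of &lt;code&gt;SIGMA&lt;/code&gt; and &lt;code&gt;SIGMA_ALPHA&lt;/code&gt; and both helper functions are assumptions for illustration.&lt;/p&gt;

```python
def pooling_weight(n, sigma, sigma_alpha):
    """Weight on the group's own mean in the conditional posterior mean
    of its intercept (Normal model, variances treated as known)."""
    data_precision = n / sigma ** 2
    prior_precision = 1.0 / sigma_alpha ** 2
    return data_precision / (data_precision + prior_precision)


def shrunk_mean(group_mean, population_mean, weight):
    """Precision-weighted compromise between group and population mean."""
    return weight * group_mean + (1.0 - weight) * population_mean


# Group sizes follow the article; SIGMA and SIGMA_ALPHA are hypothetical.
SIZES = {"Auto": 500, "Home": 300, "Commercial": 50}
SIGMA, SIGMA_ALPHA = 1.0, 0.3

weights = {name: pooling_weight(n, SIGMA, SIGMA_ALPHA) for name, n in SIZES.items()}
for name, w in weights.items():
    print(f"{name:11s} weight on own data = {w:.3f}")
```

&lt;p&gt;The weight on a group's own data grows with its sample size, so the smallest group keeps the least of its raw estimate and moves furthest toward the population mean.&lt;/p&gt;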

&lt;h3&gt;
  
  
  MCMC Diagnostics with ArviZ
&lt;/h3&gt;

&lt;p&gt;Before trusting the results, we need to verify the sampler converged. ArviZ provides the standard toolkit:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot_trace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hierarchical_trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ft1tllij7tg3kzv6mxw9b.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ft1tllij7tg3kzv6mxw9b.webp" alt="ArviZ trace plots for the hierarchical model. Top row: alpha (three group posteriors and traces). Middle row: beta (three group posteriors and traces). Bottom row: sigma (shared noise parameter). All chains show stable mixing." width="800" height="570"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Three things to check:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Trace mixing&lt;/strong&gt;: The chains should look like "hairy caterpillars", bouncing randomly around a stable mean. If a chain gets stuck or drifts, something is wrong.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;R-hat&lt;/strong&gt; (the Gelman-Rubin statistic): Should be below 1.01 for every parameter. Values above 1.1 indicate the chains haven't converged to the same distribution.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Effective sample size (ESS)&lt;/strong&gt;: Should be at least 400 in total, which is roughly 100 per chain when running four chains. Low ESS means the samples are highly autocorrelated and the posterior estimates are unreliable.
&lt;/li&gt;
&lt;/ol&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;summary&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;az&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hierarchical_trace&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;var_names&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sigma&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;summary&lt;/span&gt;&lt;span class="p"&gt;[[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;mean&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;sd&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;hdi_3%&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;hdi_97%&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;r_hat&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;ess_bulk&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
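
&lt;p&gt;For intuition about what the &lt;code&gt;r_hat&lt;/code&gt; column measures, here is a minimal sketch of the classic Gelman-Rubin statistic, which compares between-chain and within-chain variance. ArviZ reports a more robust rank-normalised split variant by default, so treat this as an illustration of the idea, not a reimplementation. The &lt;code&gt;gelman_rubin&lt;/code&gt; function and the simulated chains are hypothetical.&lt;/p&gt;

```python
import math
import random
import statistics


def gelman_rubin(chains):
    """Classic (non-split) Gelman-Rubin R-hat for equal-length chains."""
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    within = statistics.fmean([statistics.variance(c) for c in chains])
    between = n * statistics.variance(chain_means)
    pooled_var = (n - 1) / n * within + between / n
    return math.sqrt(pooled_var / within)


rng = random.Random(0)
# Four chains sampling the same distribution: R-hat should sit near 1.
mixed = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(4)]
# Two chains stuck around a different value: R-hat should be well above 1.
stuck = [[rng.gauss(shift, 1.0) for _ in range(500)]
         for shift in (0.0, 0.0, 3.0, 3.0)]

print("well mixed:", round(gelman_rubin(mixed), 3))
print("not mixed: ", round(gelman_rubin(stuck), 3))
```

&lt;p&gt;When every chain explores the same distribution, the between-chain variance adds almost nothing and the ratio stays close to 1. Chains that disagree about the mean inflate it.&lt;/p&gt;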



&lt;p&gt;If you've worked through our &lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Metropolis-Hastings&lt;/a&gt; tutorial, you'll recognise the core idea: the sampler explores the posterior by proposing moves and accepting or rejecting them. PyMC uses the NUTS sampler (No U-Turn Sampler), a sophisticated variant of Hamiltonian Monte Carlo that automatically tunes step sizes and trajectory lengths.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Why Not a Normal Likelihood?
&lt;/h3&gt;

&lt;p&gt;The model above uses a Normal likelihood on log claim severity, which assumes the log amounts have thin tails around the mean. In practice, insurance claims are &lt;strong&gt;heavy-tailed&lt;/strong&gt;: most claims are small, but a few are enormous, and extreme values can survive the log transform. The original code I adapted for this tutorial used a &lt;a href="https://en.wikipedia.org/wiki/Laplace_distribution" rel="noopener noreferrer"&gt;Laplace likelihood&lt;/a&gt; to handle this:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dy_i%2520%255Csim%2520%255Ctext%257BLaplace%257D%28%255Cmu_i%252C%2520b_i%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dy_i%2520%255Csim%2520%255Ctext%257BLaplace%257D%28%255Cmu_i%252C%2520b_i%29" alt="equation" width="202" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The Laplace distribution has heavier tails than the Normal and is more robust to outliers. In PyMC, swapping the likelihood is a single line change:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Replace:  pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y)
# With:     pm.Laplace('y_obs', mu=mu, b=b, observed=y)
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
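
&lt;p&gt;To quantify "heavier tails", the sketch below compares the probability of landing more than &lt;code&gt;k&lt;/code&gt; standard deviations from the centre under a Normal and under a Laplace with the same standard deviation. Both tail probabilities are closed-form; the function names are hypothetical.&lt;/p&gt;

```python
import math


def normal_tail(k):
    """P(|X - mu| exceeds k standard deviations) for a Normal."""
    return math.erfc(k / math.sqrt(2.0))


def laplace_tail(k):
    """Same probability for a Laplace with the same standard deviation.

    A Laplace with scale b has variance 2 * b**2, so b = sd / sqrt(2) and
    the two-sided tail beyond k * sd is exp(-k * sqrt(2)).
    """
    return math.exp(-k * math.sqrt(2.0))


for k in (2, 3, 4, 5):
    ratio = laplace_tail(k) / normal_tail(k)
    print(f"{k} sd: normal={normal_tail(k):.2e} "
          f"laplace={laplace_tail(k):.2e} ratio={ratio:.1f}")
```

&lt;p&gt;The ratio grows quickly with &lt;code&gt;k&lt;/code&gt;, which is why a single extreme claim drags a Normal fit much harder than a Laplace fit.&lt;/p&gt;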



&lt;h3&gt;
  
  
  Modelling the Spread Too: Heteroscedastic Regression
&lt;/h3&gt;

&lt;p&gt;The original code goes further. It models &lt;strong&gt;both&lt;/strong&gt; the location &lt;code&gt;$\mu$&lt;/code&gt; and the scale &lt;code&gt;$b$&lt;/code&gt; of the Laplace distribution as functions of the covariates. This is heteroscedastic regression: the amount of noise varies across observations.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu_i%2520%253D%2520%255Cexp%255C%21%255Cleft%28%255Cbeta_0%255E%257B%28j%29%257D%2520%252B%2520%255Cbeta_1%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi1%257D%2520%252B%2520%255Ccdots%2520%252B%2520%255Cbeta_4%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi4%257D%255Cright%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu_i%2520%253D%2520%255Cexp%255C%21%255Cleft%28%255Cbeta_0%255E%257B%28j%29%257D%2520%252B%2520%255Cbeta_1%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi1%257D%2520%252B%2520%255Ccdots%2520%252B%2520%255Cbeta_4%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi4%257D%255Cright%29" alt="equation" width="490" height="45"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Db_i%2520%253D%2520%255Cexp%255C%21%255Cleft%28%255Cgamma_0%255E%257B%28j%29%257D%2520%252B%2520%255Cgamma_1%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi1%257D%2520%252B%2520%255Ccdots%2520%252B%2520%255Cgamma_4%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi4%257D%255Cright%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Db_i%2520%253D%2520%255Cexp%255C%21%255Cleft%28%255Cgamma_0%255E%257B%28j%29%257D%2520%252B%2520%255Cgamma_1%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi1%257D%2520%252B%2520%255Ccdots%2520%252B%2520%255Cgamma_4%255E%257B%28j%29%257D%2520%255Clog%2520x_%257Bi4%257D%255Cright%29" alt="equation" width="481" height="45"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The &lt;code&gt;$\exp$&lt;/code&gt; ensures both &lt;code&gt;$\mu$&lt;/code&gt; and &lt;code&gt;$b$&lt;/code&gt; are positive (claim severity can't be negative). Each &lt;code&gt;$\beta$&lt;/code&gt; and &lt;code&gt;$\gamma$&lt;/code&gt; coefficient gets its own hierarchical structure:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;full_model&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;n_groups&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;  &lt;span class="c1"&gt;# policy types
&lt;/span&gt;
    &lt;span class="c1"&gt;# Hyperpriors for intercept
&lt;/span&gt;    &lt;span class="n"&gt;beta0_mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0_mu&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;beta0_sig&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;InverseGamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0_sig&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Group-level intercepts
&lt;/span&gt;    &lt;span class="n"&gt;beta0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta0_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta0_sig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_groups&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# ... repeat for each coefficient and for gamma (scale) parameters ...
&lt;/span&gt;
    &lt;span class="n"&gt;mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;beta0&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;group_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;beta1&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;group_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;...)&lt;/span&gt;
    &lt;span class="n"&gt;b&lt;/span&gt;  &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;gamma0&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;group_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;gamma1&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;group_idx&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;...)&lt;/span&gt;

    &lt;span class="n"&gt;y_obs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Laplace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;y_obs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;observed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
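
&lt;p&gt;As a plain-Python illustration of the log link, the sketch below evaluates the location and scale for a single covariate. The coefficient values and the &lt;code&gt;location_and_scale&lt;/code&gt; helper are hypothetical; the original model has four covariates and learns the coefficients per group.&lt;/p&gt;

```python
import math


def location_and_scale(x, beta0, beta1, gamma0, gamma1):
    """Log-link sketch with a single covariate x (x must be positive).

    mu = exp(beta0 + beta1 * log x)  and  b = exp(gamma0 + gamma1 * log x)
    """
    log_x = math.log(x)
    mu = math.exp(beta0 + beta1 * log_x)
    b = math.exp(gamma0 + gamma1 * log_x)
    return mu, b


# Hypothetical coefficients for one group, chosen only for illustration.
COEFS = dict(beta0=-1.0, beta1=0.5, gamma0=-2.0, gamma1=0.8)

for x in (0.5, 1.0, 10.0, 100.0):
    mu, b = location_and_scale(x, **COEFS)
    print(f"x={x:6.1f}  mu={mu:8.3f}  b={b:8.3f}")
```

&lt;p&gt;Both outputs stay positive for any coefficient values, and because the covariate enters through its logarithm, each one is a power law in &lt;code&gt;x&lt;/code&gt;. A non-zero &lt;code&gt;gamma1&lt;/code&gt; is what makes the noise level change from one observation to the next.&lt;/p&gt;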



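&lt;p&gt;Notice the &lt;code&gt;pm.InverseGamma&lt;/code&gt; hyperprior for the scale parameters. The InverseGamma is the conjugate prior for a Normal variance, which is why it is a common choice. Here it is passed as &lt;code&gt;sigma&lt;/code&gt;, so it acts as a prior on the standard deviation and the conjugacy is a motivation, not something the sampler exploits. With &lt;code&gt;alpha=2, beta=5&lt;/code&gt;, it places most of its mass on moderate values while its heavy right tail still allows large ones.&lt;/p&gt;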
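
&lt;p&gt;To check what an InverseGamma with &lt;code&gt;alpha=2, beta=5&lt;/code&gt; actually implies, the sketch below evaluates its mode, mean and upper-tail mass from closed-form expressions. The function names are hypothetical, and the CDF shortcut only holds for the special case &lt;code&gt;alpha=2&lt;/code&gt;.&lt;/p&gt;

```python
import math


def inverse_gamma_pdf(x, alpha, beta):
    """Density of InverseGamma(alpha, beta) at x (x must be positive)."""
    log_pdf = (alpha * math.log(beta) - math.lgamma(alpha)
               - (alpha + 1.0) * math.log(x) - beta / x)
    return math.exp(log_pdf)


def inverse_gamma_cdf_alpha2(x, beta):
    """Closed-form CDF for the special case alpha = 2."""
    t = beta / x
    return math.exp(-t) * (1.0 + t)


ALPHA, BETA = 2.0, 5.0
mode = BETA / (ALPHA + 1.0)   # where the density peaks
mean = BETA / (ALPHA - 1.0)   # finite only because alpha exceeds 1
tail_10 = 1.0 - inverse_gamma_cdf_alpha2(10.0, BETA)

print(f"mode={mode:.2f} mean={mean:.2f} P(X beyond 10)={tail_10:.3f}")
```

&lt;p&gt;The mode sits well below the mean, which reflects the long right tail: most of the prior mass is on moderate values, but large ones are far from ruled out.&lt;/p&gt;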

&lt;h3&gt;
  
  
  The Three-Tier Model
&lt;/h3&gt;

&lt;p&gt;The original code also contains a three-tier hierarchy. Instead of just grouping by policy type, it nests regions within each policy type:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Population → Policy Type → (Region × Policy Type)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;At the top level, hyper-hyperpriors define the global population. At the middle level, each policy type gets its own parameters drawn from the population. At the bottom level, each (region, policy type) combination gets parameters drawn from its policy type's distribution. The group-level parameters become 2D arrays with shape &lt;code&gt;(n_regions, n_types)&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Hyper-hyperpriors (population level)
&lt;/span&gt;&lt;span class="n"&gt;beta0_mu_mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0_mu_mu&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;beta0_mu_sig&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;InverseGamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0_mu_sig&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Hyperpriors (policy type level)
&lt;/span&gt;&lt;span class="n"&gt;beta0_mu&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0_mu&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta0_mu_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta0_mu_sig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;beta0_sig&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;InverseGamma&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0_sig&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;beta&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Priors (region × policy type level)
&lt;/span&gt;&lt;span class="n"&gt;beta0&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Normal&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;beta0&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mu&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta0_mu&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sigma&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;beta0_sig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_regions&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_types&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This allows a commercial policy in an urban area to differ from one in a suburban area, while both borrow strength from the overall commercial distribution, which itself borrows from the global population.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkpwi3jds0fdvqi0whtc1.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkpwi3jds0fdvqi0whtc1.webp" alt="Diagram of the two-tier hierarchical structure: population hyperpriors at the top feeding into policy-type parameters (Auto, Home, Commercial) in the middle, which govern the observed data at the bottom. The wider arrow to Commercial indicates more shrinkage due to its smaller sample size." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  When Not to Use Hierarchical Models
&lt;/h3&gt;

&lt;p&gt;Hierarchical models aren't always necessary. If every group has plenty of data (thousands of observations), no pooling gives nearly identical results to partial pooling because the data overwhelms the prior. The hierarchy adds complexity and sampling time for little benefit.&lt;/p&gt;

&lt;p&gt;They can also struggle with very few groups. With only 2 groups, the group-level standard deviation &lt;code&gt;$\sigma_\alpha$&lt;/code&gt; is estimated from just 2 data points (the two group-level parameters), which makes the estimate unreliable. A common rule of thumb is that hierarchical models start to pay off with about 5 or more groups, though the exact threshold depends on within-group sample sizes.&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Lindley and Smith (1972)
&lt;/h3&gt;

&lt;p&gt;The mathematical foundation was laid by Dennis Lindley and Adrian Smith in their 1972 paper "Bayes Estimates for the Linear Model." They formalised the multi-level Normal model:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmathbf%257By%257D%2520%255Cmid%2520%255Cboldsymbol%257B%255Ctheta%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%28A%255Cboldsymbol%257B%255Ctheta%257D%252C%255C%252C%2520C%29%2520%255Cqquad%2520%255Cboldsymbol%257B%255Ctheta%257D%2520%255Cmid%2520%255Cboldsymbol%257B%255Cmu%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%28B%255Cboldsymbol%257B%255Cmu%257D%252C%255C%252C%2520D%29%2520%255Cqquad%2520%255Cboldsymbol%257B%255Cmu%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%28%255Cmathbf%257B0%257D%252C%255C%252C%2520E%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmathbf%257By%257D%2520%255Cmid%2520%255Cboldsymbol%257B%255Ctheta%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%28A%255Cboldsymbol%257B%255Ctheta%257D%252C%255C%252C%2520C%29%2520%255Cqquad%2520%255Cboldsymbol%257B%255Ctheta%257D%2520%255Cmid%2520%255Cboldsymbol%257B%255Cmu%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%28B%255Cboldsymbol%257B%255Cmu%257D%252C%255C%252C%2520D%29%2520%255Cqquad%2520%255Cboldsymbol%257B%255Cmu%257D%2520%255Csim%2520%255Cmathcal%257BN%257D%28%255Cmathbf%257B0%257D%252C%255C%252C%2520E%29" alt="equation" width="636" height="26"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The key result: the posterior mean of &lt;code&gt;$\boldsymbol{\theta}$&lt;/code&gt; is a &lt;strong&gt;matrix-weighted average&lt;/strong&gt; of the group-specific &lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;MLE&lt;/a&gt; and the prior mean. Groups with more data (higher precision in &lt;code&gt;$C^{-1}$&lt;/code&gt;) weight their own MLE more heavily; groups with less data lean more on the prior. This is the formal statement of shrinkage.&lt;/p&gt;
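
&lt;p&gt;In the scalar case with known variances, the matrix-weighted average reduces to a precision-weighted average. Here is a minimal sketch, assuming a Normal likelihood with known within-group standard deviation &lt;code&gt;sigma_y&lt;/code&gt; and a Normal prior with mean &lt;code&gt;mu&lt;/code&gt; and standard deviation &lt;code&gt;tau&lt;/code&gt;. The helper name &lt;code&gt;shrunk_mean&lt;/code&gt; and all numbers are illustrative, not taken from the paper:&lt;/p&gt;

```python
import numpy as np

def shrunk_mean(group_mean, n_obs, sigma_y, mu, tau):
    """Posterior mean of a group effect under a Normal-Normal model with known variances."""
    data_precision = n_obs / sigma_y**2   # precision of the group's sample mean
    prior_precision = 1.0 / tau**2        # precision of the population-level prior
    weight = data_precision / (data_precision + prior_precision)
    return weight * group_mean + (1.0 - weight) * mu

# Same observed group mean (10.0), very different sample sizes
small = shrunk_mean(group_mean=10.0, n_obs=2, sigma_y=4.0, mu=5.0, tau=2.0)
large = shrunk_mean(group_mean=10.0, n_obs=200, sigma_y=4.0, mu=5.0, tau=2.0)
print(f"n=2:   {small:.2f}")   # 6.67, pulled most of the way toward the prior mean
print(f"n=200: {large:.2f}")   # 9.90, stays close to its own sample mean
```

&lt;p&gt;The group with 2 observations keeps only a third of the weight on its own mean, while the group with 200 observations is barely moved.&lt;/p&gt;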

&lt;h3&gt;
  
  
  Efron and Morris (1977): The James-Stein Connection
&lt;/h3&gt;

&lt;p&gt;The frequentist justification for shrinkage came from an unexpected direction. Willard James and Charles Stein showed in 1961 that their shrinkage estimator &lt;strong&gt;dominates&lt;/strong&gt; the usual sample means in terms of total squared error whenever three or more means are estimated simultaneously, and in 1977 Brad Efron and Carl Morris brought the result to a wide audience in &lt;em&gt;Scientific American&lt;/em&gt;. (The variant that shrinks group means toward the grand mean, rather than toward a fixed point, needs four or more groups.) This was a shocking result: even if the groups have nothing in common, shrinking them toward a common point reduces the expected total estimation error.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"The James-Stein estimator achieves a smaller total mean squared error than the individual sample means, for any configuration of the true means, provided there are three or more groups."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;The hierarchical Bayesian model produces estimates that are closely related to the James-Stein estimator. The Bayesian framework provides a natural explanation: when data is scarce, it's rational to hedge toward the population average rather than fully committing to a noisy local estimate.&lt;/p&gt;
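
&lt;p&gt;The dominance result is easy to check by simulation. The sketch below is an illustration under simplifying assumptions, not the setup from the article: it uses the basic James-Stein form that shrinks toward zero (rather than the grand mean), one observation per group, and a known noise variance of 1. The "true" means are arbitrary draws:&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)
k = 10                                       # number of groups (must be three or more)
true_means = rng.normal(0.0, 1.0, size=k)    # arbitrary fixed "true" group means
n_trials = 5000

# One noisy observation per group in each simulated trial (noise variance 1)
y = true_means + rng.normal(0.0, 1.0, size=(n_trials, k))

# James-Stein shrinkage factor toward zero, computed separately for each trial
shrink = 1.0 - (k - 2) / np.sum(y**2, axis=1, keepdims=True)
js = shrink * y

# Average total squared error across trials
mle_risk = np.mean(np.sum((y - true_means) ** 2, axis=1))
js_risk = np.mean(np.sum((js - true_means) ** 2, axis=1))

print(f"Sample means, total squared error: {mle_risk:.2f}")
print(f"James-Stein, total squared error:  {js_risk:.2f}")
```

&lt;p&gt;The sample means should land near a total squared error of &lt;code&gt;k&lt;/code&gt;, with the James-Stein estimate below it.&lt;/p&gt;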

&lt;h3&gt;
  
  
  Gelman and Hill (2006)
&lt;/h3&gt;

&lt;p&gt;The practical handbook for hierarchical models is Andrew Gelman and Jennifer Hill's &lt;em&gt;Data Analysis Using Regression and Multilevel/Hierarchical Models&lt;/em&gt;. Chapter 12 presents the same three-model comparison we built above (complete pooling, no pooling, partial pooling) using radon measurements across US counties. The group-level intercepts in such a model are often rewritten in the non-centred parameterisation:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha_j%2520%253D%2520%255Cmu_%255Calpha%2520%252B%2520%255Csigma_%255Calpha%2520%255Ccdot%2520%255Ceta_j%252C%2520%255Cquad%2520%255Ceta_j%2520%255Csim%2520%255Cmathcal%257BN%257D%280%252C%25201%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha_j%2520%253D%2520%255Cmu_%255Calpha%2520%252B%2520%255Csigma_%255Calpha%2520%255Ccdot%2520%255Ceta_j%252C%2520%255Cquad%2520%255Ceta_j%2520%255Csim%2520%255Cmathcal%257BN%257D%280%252C%25201%29" alt="equation" width="346" height="28"&gt;&lt;/a&gt;&lt;/p&gt;

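&lt;p&gt;This reparameterisation often improves MCMC sampling efficiency because the sampler explores a standard Normal geometry rather than a funnel-shaped one. In core PyMC you typically write this form yourself (sample &lt;code&gt;eta&lt;/code&gt;, then build &lt;code&gt;alpha&lt;/code&gt; with &lt;code&gt;pm.Deterministic&lt;/code&gt;), and it is usually the first thing to try when the sampler reports divergences.&lt;/p&gt;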
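
&lt;p&gt;A minimal NumPy sketch of the reparameterisation itself (plain forward sampling, not PyMC code; the values of &lt;code&gt;mu_alpha&lt;/code&gt; and &lt;code&gt;sigma_alpha&lt;/code&gt; are made up). Drawing a standard Normal offset and then shifting and scaling it yields the same distribution as drawing the group effect directly:&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(42)
mu_alpha, sigma_alpha = 1.5, 0.3   # illustrative population mean and scale
n_draws = 100_000

# Centred: sample the group effect directly from its population distribution
alpha_centred = rng.normal(mu_alpha, sigma_alpha, size=n_draws)

# Non-centred: sample a standard Normal offset, then shift and scale it
eta = rng.normal(0.0, 1.0, size=n_draws)
alpha_noncentred = mu_alpha + sigma_alpha * eta

print(f"centred:     mean={alpha_centred.mean():.3f}, sd={alpha_centred.std():.3f}")
print(f"non-centred: mean={alpha_noncentred.mean():.3f}, sd={alpha_noncentred.std():.3f}")
```

&lt;p&gt;Both lines should report values close to 1.5 and 0.3. The distribution of &lt;code&gt;alpha&lt;/code&gt; is unchanged; what changes is that the sampled quantity &lt;code&gt;eta&lt;/code&gt; has a prior that no longer depends on &lt;code&gt;sigma_alpha&lt;/code&gt;.&lt;/p&gt;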

&lt;p&gt;Gelman et al.'s &lt;em&gt;Bayesian Data Analysis&lt;/em&gt; (3rd edition, 2013) provides the full mathematical treatment in Chapter 5, including the relationship between hierarchical Bayes, empirical Bayes, and the James-Stein estimator.&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;The original formalism:&lt;/strong&gt; Lindley, D. V. &amp;amp; Smith, A. F. M. (1972). "Bayes estimates for the linear model." &lt;em&gt;Journal of the Royal Statistical Society: Series B&lt;/em&gt;, 34(1), 1-41.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The James-Stein connection:&lt;/strong&gt; Efron, B. &amp;amp; Morris, C. (1977). "Stein's paradox in statistics." &lt;em&gt;Scientific American&lt;/em&gt;, 236(5), 119-127.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The practical handbook:&lt;/strong&gt; Gelman, A. &amp;amp; Hill, J. (2006). &lt;em&gt;Data Analysis Using Regression and Multilevel/Hierarchical Models&lt;/em&gt;. Cambridge University Press. Chapters 11-13.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The full Bayesian treatment:&lt;/strong&gt; Gelman, A. et al. (2013). &lt;em&gt;Bayesian Data Analysis&lt;/em&gt;, 3rd ed. CRC Press. Chapter 5.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;PyMC documentation:&lt;/strong&gt; &lt;a href="https://www.pymc.io/projects/examples/en/latest/generalized_linear_models/GLM-hierarchical.html" rel="noopener noreferrer"&gt;PyMC Hierarchical Models tutorial&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Next post in this series:&lt;/strong&gt; Bayesian Survival Analysis, where we extend PyMC to handle censored data using &lt;code&gt;pm.Potential&lt;/code&gt;.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/bayes-theorem-calculator" rel="noopener noreferrer"&gt;Bayes' Theorem Calculator&lt;/a&gt; — Explore Bayesian updating interactively before diving into hierarchical models&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/ab-test-calculator" rel="noopener noreferrer"&gt;A/B Test Calculator&lt;/a&gt; — See Bayesian hypothesis testing in action, a common application of hierarchical models&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt;: The conceptual foundation for priors, posteriors, and why Bayesian estimates beat point estimates.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Island Hopping: Understanding Metropolis-Hastings&lt;/a&gt;: How the sampler that powers PyMC actually explores the posterior distribution.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/linear-regression-five-ways" rel="noopener noreferrer"&gt;Linear Regression: Five Ways&lt;/a&gt;: The non-hierarchical regression baseline that this post extends with group structure.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is hierarchical Bayesian regression?
&lt;/h3&gt;

&lt;p&gt;Hierarchical (or multilevel) regression models data that is naturally grouped (students within schools, patients within hospitals) by allowing parameters to vary across groups while sharing a common prior distribution. This "partial pooling" approach borrows strength across groups, producing better estimates for small groups than fitting each group independently.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the difference between complete pooling, no pooling, and partial pooling?
&lt;/h3&gt;

&lt;p&gt;Complete pooling ignores group differences entirely (one model for all). No pooling fits a separate model per group (no information sharing). Partial pooling (hierarchical) sits in between: each group gets its own parameters, but they are pulled towards a shared distribution. This is especially valuable when some groups have very few observations.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why use PyMC for hierarchical models?
&lt;/h3&gt;

&lt;p&gt;PyMC uses MCMC sampling to handle the complex posterior distributions that hierarchical models produce. It naturally propagates uncertainty through all levels of the hierarchy. Frequentist alternatives (like lme4 in R) can fit similar models but do not provide the same rich uncertainty quantification or flexibility for custom model structures.&lt;/p&gt;

&lt;h3&gt;
  
  
  How do I diagnose convergence in PyMC?
&lt;/h3&gt;

&lt;p&gt;Check the trace plots for good mixing (no trends, no stuck chains), verify that R-hat values are close to 1.0 (below 1.01), and ensure the effective sample size is sufficiently large (a common guideline is at least 400 in total, roughly 100 per chain when running four chains). Divergent transitions indicate that the sampler is struggling with the posterior geometry and may require reparameterisation.&lt;/p&gt;
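
&lt;p&gt;To make the R-hat check concrete, here is a sketch of the classic split R-hat formula in plain NumPy. It is an assumption-laden illustration: ArviZ, which PyMC uses for diagnostics, reports a more robust rank-normalised variant, and the helper name &lt;code&gt;split_rhat&lt;/code&gt; and the synthetic chains are made up for this example:&lt;/p&gt;

```python
import numpy as np

def split_rhat(chains):
    """Classic split R-hat for an array of draws with shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # Split each chain in two so that drift within a chain also shows up as disagreement
    halves = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = halves.shape[1]
    within = halves.var(axis=1, ddof=1).mean()       # W: mean within-chain variance
    between = n * halves.mean(axis=1).var(ddof=1)    # B: variance of chain means, scaled by n
    var_plus = (n - 1) / n * within + between / n    # pooled estimate of the posterior variance
    return np.sqrt(var_plus / within)

rng = np.random.default_rng(1)
good = rng.normal(size=(4, 1000))                     # four chains sampling the same target
bad = good + np.array([[0.0], [1.0], [2.0], [3.0]])   # four chains stuck at different locations

print(f"well-mixed chains: R-hat = {split_rhat(good):.3f}")
print(f"stuck chains:      R-hat = {split_rhat(bad):.3f}")
```

&lt;p&gt;The well-mixed chains should come out very close to 1.0, while the chains that disagree about the location are far above the 1.01 threshold.&lt;/p&gt;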

&lt;h3&gt;
  
  
  When should I use a hierarchical model instead of a standard regression?
&lt;/h3&gt;

&lt;p&gt;Use hierarchical models whenever your data has a natural grouping structure and you want to make inferences about individual groups. They are especially valuable when group sizes are unequal: small groups benefit from borrowing strength, and large groups are barely affected by the pooling. If all groups have abundant data, the results will be similar to fitting separate models.&lt;/p&gt;

</description>
      <category>bayesian</category>
      <category>probabilistic</category>
      <category>inference</category>
      <category>pymc</category>
    </item>
    <item>
      <title>Solving CartPole Without Gradients: Simulated Annealing</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Thu, 23 Apr 2026 07:51:02 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/solving-cartpole-without-gradients-simulated-annealing-3e47</link>
      <guid>https://dev.to/berkan_sesen/solving-cartpole-without-gradients-simulated-annealing-3e47</guid>
      <description>&lt;p&gt;In the &lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;previous post&lt;/a&gt;, we solved CartPole using the Cross-Entropy Method: sample 200 candidate policies, keep the best 40, refit a Gaussian, repeat. It worked beautifully, reaching a perfect score of 500 in 50 iterations. But 200 candidates per iteration means 10,000 total episode evaluations. That got me wondering: do we really need a population of 200 to find four good numbers?&lt;/p&gt;

&lt;p&gt;The original code that inspired this post took a radically simpler approach. Instead of maintaining a population, it kept a single set of parameters and perturbed them once per iteration. If the perturbation improved the score, it was accepted and the perturbation range was shrunk. That's it. No population, no distribution fitting, no gradients. The comment in the source file read: "its like simulated annealing." By the end of this post, you'll implement this algorithm from scratch, solve CartPole-v1 with a perfect 500 score, and understand how it connects to the rich theory of simulated annealing.&lt;/p&gt;

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;Click the badge to open the interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/simulated_annealing_cartpole.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsrnztpxlba6spbfbo3t5.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fsrnztpxlba6spbfbo3t5.gif" alt="Simulated annealing convergence animation: best score climbs from ~10 to 500 by iteration 41, then holds steady" width="800" height="400"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Here's the complete implementation. Like CEM, we use a linear policy with 4 parameters (one per observation dimension). But instead of sampling a population, we perturb a single solution:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;gymnasium&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Run multiple episodes with a linear policy and return the average reward.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;total_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;episode_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
        &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
            &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;episode_reward&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;
            &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;
        &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;close&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;total_reward&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;episode_reward&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;total_reward&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;simulated_annealing&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;80&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_eval_episodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                        &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decay&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Hill climbing with annealing step size for policy search.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;best_score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;best_theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_eval_episodes&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_iter&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# Perturb current best (uniform noise scaled by alpha)
&lt;/span&gt;        &lt;span class="n"&gt;perturbation&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rand&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;
        &lt;span class="n"&gt;candidate&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;perturbation&lt;/span&gt;

        &lt;span class="c1"&gt;# Evaluate candidate over multiple episodes
&lt;/span&gt;        &lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;candidate&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_eval_episodes&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Accept only if better, then shrink step size
&lt;/span&gt;        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;best_score&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;candidate&lt;/span&gt;
            &lt;span class="n"&gt;best_score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;score&lt;/span&gt;
            &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;*=&lt;/span&gt; &lt;span class="n"&gt;decay&lt;/span&gt;

        &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Iter &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="n"&gt;d&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | Score: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;score&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | Best: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_score&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | Alpha: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;best_theta&lt;/span&gt;

&lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;simulated_annealing&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CartPole-v1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Iter   1 | Score:   9.6 | Best:   9.6 | Alpha: 1.0000
# Iter   9 | Score: 128.7 | Best: 128.7 | Alpha: 0.6561
# Iter  14 | Score: 314.2 | Best: 314.2 | Alpha: 0.5314
# Iter  24 | Score: 465.7 | Best: 465.7 | Alpha: 0.4783
# Iter  41 | Score: 500.0 | Best: 500.0 | Alpha: 0.3874
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Perfect score in 41 iterations. Let's verify with 100 evaluation episodes:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CartPole-v1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;best_theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_episodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Mean: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; +/- &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Mean: 496 +/- 12
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Four parameters, zero gradients, and roughly 800 episode evaluations in total (80 iterations × 10 episodes each, plus 10 for the initial evaluation). Compare that to &lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;CEM&lt;/a&gt;'s 10,000 episodes or &lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;REINFORCE&lt;/a&gt;'s 5,000.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;The algorithm maintains a single candidate solution and improves it through a cycle of perturb, evaluate, and accept. Here's the full loop:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fe4tvxovb8mfum7rz4und.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fe4tvxovb8mfum7rz4und.webp" alt="SA algorithm flow: start with zeros, perturb with noise scaled by alpha, evaluate over 10 episodes, accept if better (shrink alpha) or reject (keep current)" width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;
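
&lt;p&gt;The same perturb, evaluate, accept loop can be tried without Gymnasium. The sketch below swaps the CartPole evaluation for a hypothetical two-parameter toy objective with a single peak, so &lt;code&gt;objective&lt;/code&gt; and &lt;code&gt;hill_climb&lt;/code&gt; are stand-ins rather than the code from this post; the update rule (uniform noise scaled by &lt;code&gt;alpha&lt;/code&gt;, shrink by &lt;code&gt;decay&lt;/code&gt; on every accepted move) is the same:&lt;/p&gt;

```python
import numpy as np

def objective(theta):
    """Toy stand-in for the average episode return: a single peak at (0.5, -0.5)."""
    target = np.array([0.5, -0.5])
    return -np.sum((theta - target) ** 2)

def hill_climb(n_iter=500, alpha=1.0, decay=0.9, seed=0):
    rng = np.random.default_rng(seed)
    best_theta = np.zeros(2)
    best_score = objective(best_theta)
    for _ in range(n_iter):
        # Perturb: uniform noise in [-alpha/2, alpha/2) around the current best
        candidate = best_theta + (rng.random(2) - 0.5) * alpha
        score = objective(candidate)
        # Accept only strict improvements, and shrink the step size when one is found
        if score > best_score:
            best_theta, best_score = candidate, score
            alpha *= decay
    return best_theta, best_score, alpha

theta, score, final_alpha = hill_climb()
print(np.round(theta, 3), round(float(score), 5), round(final_alpha, 4))
```

&lt;p&gt;Because the step size only shrinks after a success, the search takes large steps early on and fine-tunes once improvements become rare.&lt;/p&gt;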

&lt;p&gt;Let's walk through each piece.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Linear Policy
&lt;/h3&gt;

&lt;p&gt;As in the CEM post, we use CartPole, whose observation is a 4-dimensional vector (cart position, cart velocity, pole angle, pole angular velocity). Our policy thresholds a simple dot product:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is a linear classifier: push right if the weighted sum of observations is positive, push left otherwise. The entire "intelligence" of the agent lives in four numbers.&lt;/p&gt;
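
&lt;p&gt;As a minimal sketch, the one-liner above can be wrapped in a function. The name &lt;code&gt;linear_policy&lt;/code&gt; and the example weights are illustrative assumptions, not part of the original code:&lt;/p&gt;

```python
import numpy as np


def linear_policy(theta, obs):
    """Return 1 (push right) if the weighted sum is positive, else 0 (push left)."""
    return 1 if np.dot(theta, obs) > 0 else 0


# Hypothetical weights that react only to pole angle and angular velocity.
theta = np.array([0.0, 0.0, 1.0, 1.0])

# Observation layout: [cart position, cart velocity, pole angle, pole angular velocity]
tilting_right = np.array([0.0, 0.0, 0.05, 0.2])
tilting_left = np.array([0.0, 0.0, -0.05, -0.2])

print(linear_policy(theta, tilting_right))  # 1
print(linear_policy(theta, tilting_left))   # 0
```

&lt;p&gt;Note that an all-zeros &lt;code&gt;theta&lt;/code&gt; gives a dot product of exactly 0, so such a policy always pushes left.&lt;/p&gt;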

&lt;h3&gt;
  
  
  Multi-Episode Evaluation
&lt;/h3&gt;

&lt;p&gt;The original code's key insight (noted in a comment: "key thing was to figure out that you need to do 10 tests per point") is to evaluate each candidate over 10 episodes and average the scores. CartPole has stochastic initial conditions, so a single episode can be misleading. A policy might score 500 on one lucky initialisation and 50 on the next. Averaging over 10 episodes gives a stable estimate of true quality.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;candidate&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_eval_episodes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
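
&lt;p&gt;The body of &lt;code&gt;evaluate_policy&lt;/code&gt; is not shown here, so the following is only a sketch of what multi-episode averaging might look like. It assumes a Gymnasium-style &lt;code&gt;reset&lt;/code&gt;/&lt;code&gt;step&lt;/code&gt; interface and uses a hypothetical &lt;code&gt;StubEnv&lt;/code&gt; (not CartPole) so that it runs without extra dependencies. The names &lt;code&gt;StubEnv&lt;/code&gt; and &lt;code&gt;evaluate_policy_sketch&lt;/code&gt; are assumptions:&lt;/p&gt;

```python
import numpy as np


class StubEnv:
    """Hypothetical stand-in with a Gymnasium-style reset/step interface.

    It is NOT CartPole: each episode simply ends after a random number of
    steps, which is enough to show why one episode is a noisy estimate.
    """

    def __init__(self, rng):
        self.rng = rng
        self.steps_left = 0

    def reset(self):
        # Episode length drawn uniformly from 50..500 (inclusive)
        self.steps_left = int(self.rng.integers(50, 501))
        return np.zeros(4), {}

    def step(self, action):
        self.steps_left -= 1
        terminated = self.steps_left == 0
        # Gymnasium-style 5-tuple: obs, reward, terminated, truncated, info
        return np.zeros(4), 1.0, terminated, False, {}


def evaluate_policy_sketch(env, theta, n_eval_episodes=10):
    """Average the total reward of a linear policy over several episodes."""
    totals = []
    for _ in range(n_eval_episodes):
        obs, _info = env.reset()
        total, done = 0.0, False
        while not done:
            action = 1 if np.dot(theta, obs) > 0 else 0
            obs, reward, terminated, truncated, _info = env.step(action)
            total += reward
            done = terminated or truncated
        totals.append(total)
    return float(np.mean(totals))


env = StubEnv(np.random.default_rng(0))
theta = np.zeros(4)
single = [evaluate_policy_sketch(env, theta, n_eval_episodes=1) for _ in range(100)]
averaged = [evaluate_policy_sketch(env, theta, n_eval_episodes=10) for _ in range(100)]
print(f"spread of 1-episode scores:  {np.std(single):.1f}")
print(f"spread of 10-episode scores: {np.std(averaged):.1f}")
```

&lt;p&gt;In this toy setting the 10-episode averages vary far less than single-episode scores, which is the reason the accept/reject comparison becomes trustworthy.&lt;/p&gt;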



&lt;h3&gt;
  
  
  The Perturbation Step
&lt;/h3&gt;

&lt;p&gt;Each iteration, we perturb the current best parameters with uniform noise scaled by &lt;code&gt;alpha&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;perturbation&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rand&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;
&lt;span class="n"&gt;candidate&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;perturbation&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;When &lt;code&gt;alpha=1.0&lt;/code&gt;, each parameter can change by up to &lt;code&gt;$\pm 0.5$&lt;/code&gt;. As alpha shrinks, the perturbations get smaller, focusing the search around the current best.&lt;/p&gt;

&lt;h3&gt;
  
  
  Accept and Anneal
&lt;/h3&gt;

&lt;p&gt;Here's the crucial part. We only accept improvements, and we only shrink the step size when we find one:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="n"&gt;best_score&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;candidate&lt;/span&gt;
    &lt;span class="n"&gt;best_score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;score&lt;/span&gt;
    &lt;span class="n"&gt;alpha&lt;/span&gt; &lt;span class="o"&gt;*=&lt;/span&gt; &lt;span class="n"&gt;decay&lt;/span&gt;  &lt;span class="c1"&gt;# Shrink step size by 10%
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is an adaptive cooling schedule. If the algorithm keeps finding improvements, alpha decays quickly (&lt;code&gt;$0.9^9 \approx 0.39$&lt;/code&gt; after 9 improvements). If it gets stuck, alpha stays large, maintaining exploration. The algorithm found 9 improvements out of 80 iterations, ending with &lt;code&gt;$\alpha = 0.387$&lt;/code&gt;.&lt;/p&gt;
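
&lt;p&gt;Putting the perturb, evaluate, and accept steps together, the loop might look roughly like this. It is a sketch under stated assumptions: &lt;code&gt;anneal_sketch&lt;/code&gt; and &lt;code&gt;toy_score&lt;/code&gt; are hypothetical names, the objective is a simple quadratic standing in for the CartPole evaluation, and NumPy's &lt;code&gt;default_rng&lt;/code&gt; is used for reproducibility:&lt;/p&gt;

```python
import numpy as np


def anneal_sketch(score_fn, n_params=4, n_iter=80, alpha=1.0, decay=0.9, seed=0):
    """Hill climbing with a step size that shrinks only on improvement."""
    rng = np.random.default_rng(seed)
    best_theta = np.zeros(n_params)       # start from all zeros
    best_score = score_fn(best_theta)
    n_improvements = 0
    for _ in range(n_iter):
        # Uniform noise in [-alpha/2, alpha/2) around the current best
        candidate = best_theta + (rng.random(n_params) - 0.5) * alpha
        score = score_fn(candidate)
        if score > best_score:            # accept strict improvements only
            best_theta, best_score = candidate, score
            alpha *= decay                # cool only when we improve
            n_improvements += 1
    return best_theta, best_score, alpha, n_improvements


# Hypothetical stand-in objective (not CartPole): maximum of 0 at theta = 0.5
def toy_score(theta):
    return -float(np.sum((theta - 0.5) ** 2))


theta, score, alpha, k = anneal_sketch(toy_score)
print(f"improvements={k}, final alpha={alpha:.3f}, best score={score:.4f}")
```

&lt;p&gt;Whatever the objective, the final step size always equals &lt;code&gt;decay ** k&lt;/code&gt; for &lt;code&gt;k&lt;/code&gt; accepted improvements, which is where the 0.387 figure above comes from.&lt;/p&gt;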

&lt;h3&gt;
  
  
  The Training Curve
&lt;/h3&gt;

&lt;p&gt;The staircase pattern tells the story. Each vertical jump is an accepted improvement; each flat region is the algorithm searching without finding anything better:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;scatter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:],&lt;/span&gt; &lt;span class="n"&gt;candidate_scores&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:],&lt;/span&gt;
            &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;#2ecc71&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;#e74c3c&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;accepted&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]],&lt;/span&gt;
            &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Candidates&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;best_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;b-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best score&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axhline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;500&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;k&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;:&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Max possible (500)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;ax2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;twinx&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alphas&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;k--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Step size (α)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Step size (α)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gray&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fl5cgzdvgyy2y79rw9m2y.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fl5cgzdvgyy2y79rw9m2y.webp" alt="SA training curve showing staircase improvements with candidate scores as coloured dots and step size decay on secondary axis" width="800" height="394"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Green dots are accepted candidates (improvements); red dots are rejected ones. The dashed grey line shows the step size &lt;code&gt;$\alpha$&lt;/code&gt; shrinking on the secondary axis. Notice how the red dots cluster higher as the search progresses, because even rejected perturbations from a good solution tend to produce decent policies.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Hill Climbing vs True Simulated Annealing
&lt;/h3&gt;

&lt;p&gt;Let's be precise about what our algorithm is. The original code's comment called it "like simulated annealing," and that's accurate, but with an important distinction.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Our algorithm (hill climbing with annealing step size):&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Accepts only improvements&lt;/li&gt;
&lt;li&gt;Shrinks the step size when an improvement is found&lt;/li&gt;
&lt;li&gt;Never accepts a worse solution&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;True simulated annealing:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Accepts improvements always&lt;/li&gt;
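&lt;li&gt;Accepts worse solutions with probability &lt;code&gt;$e^{-\Delta E / T}$&lt;/code&gt;, where &lt;code&gt;$\Delta E$&lt;/code&gt; is the increase in cost and &lt;code&gt;$T$&lt;/code&gt; is the temperature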
&lt;/li&gt;
&lt;li&gt;Shrinks the temperature &lt;code&gt;$T$&lt;/code&gt; on a fixed schedule&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The difference is in how they handle worse solutions. True SA occasionally accepts a downhill move, which allows it to escape local optima. Our algorithm never does, which makes it a strict hill climber. The "annealing" part is only in the step size, not in the acceptance criterion.&lt;/p&gt;

&lt;p&gt;For CartPole with a 4-parameter linear policy, this distinction doesn't matter: the reward landscape is smooth enough that hill climbing works. For harder problems with many local optima, true SA's ability to escape traps becomes essential.&lt;/p&gt;

&lt;p&gt;If you've read the &lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Metropolis-Hastings post&lt;/a&gt;, the acceptance criterion should look familiar. The Metropolis acceptance probability &lt;code&gt;$\min(1, e^{-\Delta E / T})$&lt;/code&gt; is exactly what true SA uses. In MCMC, we want to sample from a distribution; in SA, we want to find its peak. Same mechanism, different goal.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Cooling Schedule
&lt;/h3&gt;

&lt;p&gt;Our algorithm uses a multiplicative decay: &lt;code&gt;$\alpha_{k+1} = 0.9 \cdot \alpha_k$&lt;/code&gt; on each improvement. This creates a geometric sequence:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha_k%2520%253D%2520%255Calpha_0%2520%255Ccdot%2520%255Cgamma%255Ek%2520%253D%25201.0%2520%255Ccdot%25200.9%255Ek" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Calpha_k%2520%253D%2520%255Calpha_0%2520%255Ccdot%2520%255Cgamma%255Ek%2520%253D%25201.0%2520%255Ccdot%25200.9%255Ek" alt="equation" width="249" height="28"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$k$&lt;/code&gt; is the number of improvements found. After 9 improvements, &lt;code&gt;$\alpha = 0.9^9 \approx 0.387$&lt;/code&gt;.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alphas&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;b-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;       &lt;span class="c1"&gt;# Alpha vs iterations
&lt;/span&gt;&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;steps&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;          &lt;span class="c1"&gt;# Geometric decay curves
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fs92kvrqg9qt5eryy4660.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fs92kvrqg9qt5eryy4660.webp" alt="Cooling schedule: left panel shows step size over iterations with green bars marking improvements; right panel compares geometric decay rates" width="800" height="261"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The left panel shows alpha over iterations, with green bands marking accepted improvements. The right panel compares different decay rates. A faster decay (&lt;code&gt;$\gamma = 0.8$&lt;/code&gt;) converges to fine-tuning quickly but risks getting stuck. A slower decay (&lt;code&gt;$\gamma = 0.95$&lt;/code&gt;) explores longer but takes more iterations to refine. The original code's choice of 0.9 strikes a reasonable balance.&lt;/p&gt;

&lt;p&gt;What makes our schedule adaptive is that it only decays on improvement. Traditional SA uses fixed schedules (logarithmic, linear, or exponential decay as a function of the iteration count). Our variant keeps &lt;code&gt;$\alpha$&lt;/code&gt; large during plateaus, naturally spending more time exploring when stuck and more time refining when making progress.&lt;/p&gt;

&lt;h3&gt;
  
  
  SA vs CEM: One Climber vs a Search Party
&lt;/h3&gt;

&lt;p&gt;The &lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;Cross-Entropy Method&lt;/a&gt; we built last time and simulated annealing sit at opposite ends of the derivative-free spectrum:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Aspect&lt;/th&gt;
&lt;th&gt;Simulated Annealing&lt;/th&gt;
&lt;th&gt;CEM&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Search strategy&lt;/td&gt;
&lt;td&gt;Single point, local perturbations&lt;/td&gt;
&lt;td&gt;Population of 200, distribution fitting&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Episodes per iteration&lt;/td&gt;
&lt;td&gt;10&lt;/td&gt;
&lt;td&gt;200 (200 candidates x 1 each)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Total episodes to solve CartPole&lt;/td&gt;
&lt;td&gt;~800&lt;/td&gt;
&lt;td&gt;~10,000 (200 x 50 iterations)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Information used&lt;/td&gt;
&lt;td&gt;"Is this better than the best?" (1 bit)&lt;/td&gt;
&lt;td&gt;Full reward ranking of all candidates&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Robustness&lt;/td&gt;
&lt;td&gt;Seed-dependent; some runs may fail&lt;/td&gt;
&lt;td&gt;Highly robust; population averages out noise&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Parallelisable&lt;/td&gt;
&lt;td&gt;No (sequential by nature)&lt;/td&gt;
&lt;td&gt;Yes (all 200 evaluations are independent)&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;SA is like a single hiker exploring a mountain range, taking one step at a time and only moving to higher ground. CEM is like sending 200 hikers, ranking them by altitude, and teleporting the next batch to the region where the best ones clustered.&lt;/p&gt;

&lt;p&gt;SA wins on sample efficiency (fewer total episodes) but loses on reliability. Run SA with a different random seed and you might need 20 iterations or 200. CEM's population averaging makes it much more consistent.&lt;/p&gt;

&lt;h3&gt;
  
  
  SA vs Random Search
&lt;/h3&gt;

&lt;p&gt;How much does the "annealing" (building on previous improvements) actually help, compared to just sampling random policies each time?&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sa_best_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;b-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Simulated annealing&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_best_scores&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;r--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Random search&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frdd467x63atvwvt9e5kx.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frdd467x63atvwvt9e5kx.webp" alt="SA reaching 500 while random search plateaus at 387 after 80 iterations" width="800" height="394"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Random search samples a fresh random policy each iteration (uniform in &lt;code&gt;$[-1, 1]^4$&lt;/code&gt;) and tracks the best one found. After 80 iterations, its best score is 387 vs SA's 500. Random search got lucky once (iteration 2) and found a decent policy early, but it can never refine it. SA's ability to make small improvements to an already-good solution is what pushes it from "decent" to "perfect."&lt;/p&gt;

&lt;h3&gt;
  
  
  Hyperparameters
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Parameter&lt;/th&gt;
&lt;th&gt;Value&lt;/th&gt;
&lt;th&gt;Effect&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;alpha&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;1.0&lt;/td&gt;
&lt;td&gt;Initial step size. Perturbations range in &lt;code&gt;$[-0.5, 0.5]$&lt;/code&gt; per parameter&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;decay&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;0.9&lt;/td&gt;
&lt;td&gt;Step size multiplier on improvement. Lower = faster convergence, less exploration&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;n_iter&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;80&lt;/td&gt;
&lt;td&gt;Total iterations. Our run converged at iteration 41&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;n_eval_episodes&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;10&lt;/td&gt;
&lt;td&gt;Episodes per evaluation. More = less noise, more compute&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;The most sensitive parameter is &lt;code&gt;decay&lt;/code&gt;. At 0.9, alpha halves after about 7 improvements. At 0.8, it halves after 4. Too aggressive and the step size collapses before finding a good solution; too conservative and you waste iterations on large perturbations when you're already close.&lt;/p&gt;
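
&lt;p&gt;The halving figures quoted above can be checked with a few lines of arithmetic. The helper name &lt;code&gt;improvements_to_halve&lt;/code&gt; is an illustrative assumption:&lt;/p&gt;

```python
import math


def improvements_to_halve(decay):
    """Smallest whole number of improvements k such that decay**k is at most 0.5."""
    return math.ceil(math.log(0.5) / math.log(decay))


for decay in (0.8, 0.9, 0.95):
    k = improvements_to_halve(decay)
    print(f"decay={decay}: {k} improvements, alpha={decay ** k:.3f}")
# decay=0.8: 4 improvements, alpha=0.410
# decay=0.9: 7 improvements, alpha=0.478
# decay=0.95: 14 improvements, alpha=0.488
```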

&lt;h3&gt;
  
  
  When NOT to Use This Approach
&lt;/h3&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;High-dimensional parameter spaces.&lt;/strong&gt; A single perturbation in 1000 dimensions is unlikely to improve on the current best by chance. Population methods like &lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;CEM&lt;/a&gt; or &lt;a href="https://sesen.ai/blog/genetic-algorithms-from-scratch" rel="noopener noreferrer"&gt;genetic algorithms&lt;/a&gt; scale better&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Multi-modal reward landscapes.&lt;/strong&gt; Our hill climber can only find the nearest peak. If the global optimum is separated by a valley, you'll never reach it without true SA's downhill acceptance&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;When you need guarantees.&lt;/strong&gt; SA is a heuristic. Even true SA only guarantees convergence to the global optimum with logarithmic cooling, which is impractically slow&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;When wall-clock time matters more than sample efficiency.&lt;/strong&gt; SA is inherently sequential. CEM's 200 evaluations per iteration can run in parallel, making it faster on multi-core hardware despite using about 12x more episodes&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;p&gt;Simulated annealing was introduced independently by &lt;strong&gt;Scott Kirkpatrick, Daniel Gelatt, and Mario Vecchi&lt;/strong&gt; at IBM Research in their 1983 Science paper &lt;a href="https://doi.org/10.1126/science.220.4598.671" rel="noopener noreferrer"&gt;"Optimization by Simulated Annealing"&lt;/a&gt;, and by &lt;strong&gt;Vladimír Černý&lt;/strong&gt; in 1985. The name comes from the metallurgical process of annealing: heating a metal and then slowly cooling it to reduce defects in its crystal structure.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Metallurgy Analogy
&lt;/h3&gt;

&lt;p&gt;When you heat metal, atoms vibrate wildly and can escape local energy minima. As the temperature drops, atoms settle into increasingly stable configurations. If you cool slowly enough, the metal reaches its lowest-energy crystal state (the global optimum). Cool too fast and you get a brittle, disordered structure (a local optimum).&lt;/p&gt;

&lt;p&gt;Kirkpatrick and colleagues mapped this physical process to combinatorial optimisation:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Atomic configurations&lt;/strong&gt; become candidate solutions&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Energy&lt;/strong&gt; becomes the cost function&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Temperature&lt;/strong&gt; becomes a control parameter that governs randomness&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  The Metropolis Connection
&lt;/h3&gt;

&lt;p&gt;The acceptance criterion in true SA comes directly from the &lt;strong&gt;Metropolis algorithm&lt;/strong&gt; (Metropolis, Rosenbluth, Rosenbluth, Teller, and Teller, 1953), originally designed for simulating atomic systems in statistical mechanics. At temperature &lt;code&gt;$T$&lt;/code&gt;, a new state with energy &lt;code&gt;$E'$&lt;/code&gt; is accepted from a current state with energy &lt;code&gt;$E$&lt;/code&gt; with probability:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28%255Ctext%257Baccept%257D%29%2520%253D%2520%255Cbegin%257Bcases%257D%25201%2520%2526%2520%255Ctext%257Bif%2520%257D%2520E%27%2520%253C%2520E%2520%255C%255C%2520e%255E%257B-%28E%27%2520-%2520E%29%2520%252F%2520T%257D%2520%2526%2520%255Ctext%257Bif%2520%257D%2520E%27%2520%255Cgeq%2520E%2520%255Cend%257Bcases%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28%255Ctext%257Baccept%257D%29%2520%253D%2520%255Cbegin%257Bcases%257D%25201%2520%2526%2520%255Ctext%257Bif%2520%257D%2520E%27%2520%253C%2520E%2520%255C%255C%2520e%255E%257B-%28E%27%2520-%2520E%29%2520%252F%2520T%257D%2520%2526%2520%255Ctext%257Bif%2520%257D%2520E%27%2520%255Cgeq%2520E%2520%255Cend%257Bcases%257D" alt="equation" width="390" height="75"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;At high &lt;code&gt;$T$&lt;/code&gt;, the exponential is close to 1, so almost any move is accepted (random exploration). At low &lt;code&gt;$T$&lt;/code&gt;, only improvements or tiny degradations are accepted (local refinement). As &lt;code&gt;$T \to 0$&lt;/code&gt;, the algorithm becomes pure hill climbing.&lt;/p&gt;
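
&lt;p&gt;A minimal sketch of this acceptance rule, assuming cost minimisation with &lt;code&gt;delta_e&lt;/code&gt; equal to the new energy minus the current one. The function names are illustrative assumptions. For a score-maximising problem like CartPole, one would pass the drop in score (current best minus candidate) as &lt;code&gt;delta_e&lt;/code&gt;:&lt;/p&gt;

```python
import math
import random


def acceptance_probability(delta_e, temperature):
    """Metropolis probability min(1, exp(-delta_e / T)) of accepting a move."""
    if delta_e > 0:                      # candidate is worse
        return math.exp(-delta_e / temperature)
    return 1.0                           # equal or better: always accept


def metropolis_accept(delta_e, temperature, rng=random):
    """Draw the accept/reject decision; rng.random() lies in [0, 1)."""
    return acceptance_probability(delta_e, temperature) > rng.random()


# The same uphill move (cost increase of 10) at a high and a low temperature
print(f"{acceptance_probability(10, 100):.3f}")  # 0.905
print(f"{acceptance_probability(10, 1):.1e}")    # 4.5e-05
```

&lt;p&gt;Our hill climber corresponds to the low-temperature limit, where the probability of accepting a worse candidate is effectively zero.&lt;/p&gt;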

&lt;p&gt;This is the same acceptance probability we explored in the &lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;Metropolis-Hastings post&lt;/a&gt; for MCMC sampling. The only difference: in MCMC, we hold the temperature fixed so that the chain samples from the target distribution; in SA, we lower it over time to converge on a peak. Same mechanism, different goals.&lt;/p&gt;

&lt;h3&gt;
  
  
  Our Variant vs Classical SA
&lt;/h3&gt;

&lt;p&gt;Our implementation simplifies classical SA in two ways:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;No downhill acceptance.&lt;/strong&gt; We only accept improvements, making our algorithm a strict hill climber. Classical SA would occasionally accept a worse solution, with probability decreasing as the temperature drops&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Adaptive cooling.&lt;/strong&gt; Classical SA uses a fixed cooling schedule (e.g., &lt;code&gt;$T_k = T_0 / \log(1+k)$&lt;/code&gt; for the theoretical guarantee). Our schedule only cools when an improvement is found, which adapts the exploration rate to the difficulty of the landscape&lt;/li&gt;
&lt;/ol&gt;
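
&lt;p&gt;The two simplifications can be sketched on a toy objective. This is an illustration under assumptions, not the notebook's code: &lt;code&gt;toy_score&lt;/code&gt; stands in for an averaged episode return, and &lt;code&gt;noise_scale&lt;/code&gt; and &lt;code&gt;adaptive_hill_climb&lt;/code&gt; are hypothetical names.&lt;/p&gt;

```python
import numpy as np

def toy_score(theta):
    """Hypothetical stand-in for an episode return, peaking at [1, -2, 0.5, 3]."""
    target = np.array([1.0, -2.0, 0.5, 3.0])
    return -float(np.sum((theta - target) ** 2))

def adaptive_hill_climb(score_fn, n_params, n_iter=500, noise_scale=1.0, decay=0.95, seed=0):
    rng = np.random.default_rng(seed)
    best_theta = np.zeros(n_params)
    best_score = score_fn(best_theta)
    for _ in range(n_iter):
        # Perturb the current best with Gaussian noise of the current scale
        candidate = best_theta + noise_scale * rng.standard_normal(n_params)
        score = score_fn(candidate)
        if score > best_score:       # strict improvement only: no downhill acceptance
            best_theta, best_score = candidate, score
            noise_scale *= decay     # cool only when an improvement is found
    return best_theta, best_score, noise_scale

theta, score, final_scale = adaptive_hill_climb(toy_score, n_params=4)
print(f"best score: {score:.4f}, final noise scale: {final_scale:.4f}")
```

&lt;p&gt;Because the noise scale shrinks only on accepted moves, a run that stalls keeps its current step size instead of cooling on a timer.&lt;/p&gt;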

&lt;p&gt;Despite these simplifications, our algorithm captures SA's core idea: start with large moves (exploration) and gradually transition to small moves (exploitation). For low-dimensional problems like our 4-parameter CartPole policy, the simplified variant is enough: it reaches a perfect score without ever accepting a worse solution.&lt;/p&gt;

&lt;h3&gt;
  
  
  Theoretical Guarantees
&lt;/h3&gt;

&lt;p&gt;Kirkpatrick et al. introduced SA, but the convergence guarantee came later: Geman and Geman (1984) proved that SA with logarithmic cooling (&lt;code&gt;$T_k = c / \log(1+k)$&lt;/code&gt;) converges to the global optimum in probability, provided the constant &lt;code&gt;$c$&lt;/code&gt; is large enough. However, this schedule is impractically slow for real problems. In practice, faster geometric schedules (&lt;code&gt;$T_{k+1} = \alpha T_k$&lt;/code&gt;) are used, sacrificing the global optimality guarantee for practical convergence speed.&lt;/p&gt;
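
&lt;p&gt;A quick numerical comparison shows how differently the two schedules behave. The constants &lt;code&gt;c = 1.0&lt;/code&gt;, &lt;code&gt;t0 = 1.0&lt;/code&gt; and &lt;code&gt;alpha = 0.95&lt;/code&gt; are arbitrary values chosen for illustration.&lt;/p&gt;

```python
import math

def logarithmic_temperature(k, c=1.0):
    """T_k = c / log(1 + k), defined for k >= 1."""
    return c / math.log(1 + k)

def geometric_temperature(k, t0=1.0, alpha=0.95):
    """Closed form of T_{k+1} = alpha * T_k starting from T_0 = t0."""
    return t0 * alpha ** k

for k in (1, 10, 100, 1000):
    log_t = logarithmic_temperature(k)
    geo_t = geometric_temperature(k)
    print(f"k = {k:4d} | logarithmic: {log_t:.4f} | geometric: {geo_t:.2e}")
```

&lt;p&gt;After 1000 steps the logarithmic schedule is still above 0.14, while the geometric one has dropped to the order of 1e-23. That gap is why the theoretical schedule is rarely used as-is.&lt;/p&gt;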

&lt;blockquote&gt;
&lt;p&gt;"There is a deep and useful connection between statistical mechanics [...] and multivariate or combinatorial optimization. [...] We have applied this framework to the design of computer hardware, to a specific and practical problem in computer layout."&lt;br&gt;
&lt;em&gt;Kirkpatrick, Gelatt, and Vecchi (1983)&lt;/em&gt;&lt;/p&gt;
&lt;/blockquote&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1126/science.220.4598.671" rel="noopener noreferrer"&gt;Kirkpatrick, Gelatt, and Vecchi (1983)&lt;/a&gt;, "Optimization by Simulated Annealing" - The foundational paper. Read Section II for the algorithm and Section IV for the VLSI application&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1063/1.1699114" rel="noopener noreferrer"&gt;Metropolis et al. (1953)&lt;/a&gt;, "Equation of State Calculations by Fast Computing Machines" - The acceptance criterion used by SA, originally for molecular simulation&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1007/BF00940812" rel="noopener noreferrer"&gt;Černý (1985)&lt;/a&gt;, "Thermodynamical Approach to the Traveling Salesman Problem" - Independent invention of SA for TSP&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Sutton and Barto (2018)&lt;/strong&gt;, Ch. 1 - Context for derivative-free methods in the RL landscape&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1007/978-1-4757-4321-0" rel="noopener noreferrer"&gt;Rubinstein and Kroese (2004)&lt;/a&gt;, &lt;em&gt;The Cross-Entropy Method&lt;/em&gt; - For comparison with the population-based approach&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/simulated_annealing_cartpole.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Decay rate sweep&lt;/strong&gt;: Try &lt;code&gt;decay&lt;/code&gt; values of 0.8, 0.9, 0.95, and 0.99. How does the cooling speed affect convergence? Is there a sweet spot?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;True simulated annealing&lt;/strong&gt;: Modify the algorithm to accept worse solutions with probability &lt;code&gt;$e^{-\Delta / T}$&lt;/code&gt; where &lt;code&gt;$\Delta$&lt;/code&gt; is the score difference and &lt;code&gt;$T$&lt;/code&gt; decays on a fixed schedule. Does it help on CartPole? When would it matter?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Seed sensitivity&lt;/strong&gt;: Run the algorithm 20 times with different random seeds. What fraction of runs reach 500? How does this compare to CEM's reliability?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Harder environments&lt;/strong&gt;: Try SA on &lt;code&gt;Acrobot-v1&lt;/code&gt; or &lt;code&gt;MountainCar-v0&lt;/code&gt;. Does the 4-parameter linear policy have enough capacity, or do these environments need a richer policy class?&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/q-learning-visualizer" rel="noopener noreferrer"&gt;Q-Learning Visualiser&lt;/a&gt; - Compare SA's derivative-free approach with value-based RL on grid worlds&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/cross-entropy-method-evolution-style-rl" rel="noopener noreferrer"&gt;The Cross-Entropy Method: Solving RL Without Gradients&lt;/a&gt; - The population-based companion to SA. Both are derivative-free, but CEM trades sample efficiency for robustness by maintaining 200 candidates per iteration.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/mcmc-metropolis-hastings-island-hopping-guide" rel="noopener noreferrer"&gt;MCMC Island Hopping: Understanding Metropolis-Hastings&lt;/a&gt; - The acceptance criterion that powers true SA comes directly from the Metropolis algorithm. In MCMC we sample from a distribution; in SA we find its peak.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/genetic-algorithms-from-scratch" rel="noopener noreferrer"&gt;Genetic Algorithms: From Line Fitting to the Travelling Salesman&lt;/a&gt; - Another derivative-free optimisation family. GAs use crossover and mutation on a population; SA uses perturbation on a single solution.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  How is simulated annealing different from random search?
&lt;/h3&gt;

&lt;p&gt;Random search samples a completely new policy each iteration and tracks the best one found, but it can never refine a promising solution. Simulated annealing builds on previous improvements by perturbing the current best parameters with decreasing noise. This ability to make small refinements to an already-good solution is what pushes SA from "decent" to "perfect" on CartPole.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why does the algorithm evaluate each candidate over 10 episodes instead of 1?
&lt;/h3&gt;

&lt;p&gt;CartPole has stochastic initial conditions, so a single episode can be misleading. A policy might score 500 on one lucky initialisation and 50 on the next. Averaging over 10 episodes gives a much more stable estimate of true quality, making it far less likely that the algorithm accepts a lucky fluke or rejects a good policy due to bad luck.&lt;/p&gt;
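
&lt;p&gt;The effect of averaging can be simulated without an environment. In this sketch the episode return is a hypothetical noisy draw around a fixed "true quality" of 300, clipped to CartPole's 0 to 500 range; the noise level is an assumption chosen only to make the effect visible.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)

def noisy_episode_return(true_quality=300.0, noise_std=100.0):
    """Hypothetical single-episode return: true quality plus start-state noise."""
    return float(np.clip(rng.normal(true_quality, noise_std), 0.0, 500.0))

def averaged_return(n_episodes):
    """Score estimate obtained by averaging n_episodes independent episodes."""
    return float(np.mean([noisy_episode_return() for _ in range(n_episodes)]))

# Repeat each kind of estimate many times and compare how much it fluctuates
single = np.array([averaged_return(1) for _ in range(2000)])
averaged = np.array([averaged_return(10) for _ in range(2000)])
print(f"spread of 1-episode estimates:  {single.std():.1f}")
print(f"spread of 10-episode estimates: {averaged.std():.1f}")
```

&lt;p&gt;For independent episodes the spread of the estimate shrinks by roughly the square root of the number of episodes, so 10 episodes cut the noise by about a factor of three.&lt;/p&gt;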

&lt;h3&gt;
  
  
  Is this true simulated annealing?
&lt;/h3&gt;

&lt;p&gt;Not quite. True simulated annealing occasionally accepts worse solutions with a probability that decreases over time, allowing it to escape local optima. Our implementation is a strict hill climber that only accepts improvements. The "annealing" part refers only to the shrinking step size. For CartPole's smooth 4-parameter landscape, this distinction does not matter, but for problems with many local optima, true SA's downhill acceptance becomes essential.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why does the step size only shrink when an improvement is found?
&lt;/h3&gt;

&lt;p&gt;This creates an adaptive cooling schedule. If the algorithm keeps finding improvements, the step size decays quickly, focusing the search around the current best. If it gets stuck on a plateau, the step size stays large, maintaining broad exploration. This naturally spends more time exploring when stuck and more time refining when making progress.&lt;/p&gt;

&lt;h3&gt;
  
  
  When would simulated annealing fail compared to population-based methods?
&lt;/h3&gt;

&lt;p&gt;SA struggles in high-dimensional parameter spaces where a single random perturbation is unlikely to improve all parameters at once. Our strict hill-climbing variant also struggles on multi-modal reward landscapes: without downhill acceptance, it tends to stall on whichever peak it reaches first. Population-based methods like the Cross-Entropy Method or genetic algorithms handle both cases better by maintaining diversity across many candidates simultaneously.&lt;/p&gt;

</description>
      <category>reinforcementlearning</category>
      <category>optimisation</category>
    </item>
    <item>
      <title>The Cross-Entropy Method: Solving RL Without Gradients</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Tue, 21 Apr 2026 08:27:46 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/the-cross-entropy-method-solving-rl-without-gradients-1lol</link>
      <guid>https://dev.to/berkan_sesen/the-cross-entropy-method-solving-rl-without-gradients-1lol</guid>
      <description>&lt;p&gt;Reinforcement learning has accumulated layers of complexity over the years: value functions, policy gradients, replay buffers, target networks. The Cross-Entropy Method needs none of it. Rubinstein introduced it in 1997 for rare-event simulation, and it turned out to solve simple control tasks with almost no machinery. The entire implementation fits in 50 lines. No gradients, no backpropagation. Just: sample some parameters, test them, keep the best ones, repeat.&lt;/p&gt;

&lt;p&gt;The Cross-Entropy Method (CEM) is the algorithm you reach for when you want results without complexity. It treats the policy's parameters as a black box, maintains a probability distribution over them, and iteratively narrows that distribution toward high-performing regions. By the end of this post, you'll implement CEM from scratch, solve CartPole-v1 with a perfect score, and understand why this "naive" approach works so well on problems with manageable parameter spaces.&lt;/p&gt;

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;Click the badge to open the interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/cross_entropy_method.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fh9lzj24j8fzzyzxjhagc.gif" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fh9lzj24j8fzzyzxjhagc.gif" alt="CEM convergence animation showing the reward distribution shifting from low to high over 50 iterations" width="800" height="500"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Here's the complete implementation. We use a linear policy with just 4 parameters (one per observation dimension), and CEM finds the perfect weights:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;gymnasium&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Run one episode with a linear policy: action = 1 if theta @ obs &amp;gt; 0 else 0.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;env&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;gym&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;reset&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;total_reward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;
    &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;done&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="n"&gt;action&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;theta&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
        &lt;span class="n"&gt;obs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;action&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;total_reward&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;reward&lt;/span&gt;
        &lt;span class="n"&gt;done&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;terminated&lt;/span&gt; &lt;span class="ow"&gt;or&lt;/span&gt; &lt;span class="n"&gt;truncated&lt;/span&gt;
    &lt;span class="n"&gt;env&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;close&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;total_reward&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;cem&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;elite_frac&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;initial_std&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;extra_std&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;std_decay_time&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;25&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Cross-Entropy Method for policy search.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="n"&gt;n_elite&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;round&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;elite_frac&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="n"&gt;th_mean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;th_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;initial_std&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;iteration&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_iter&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# Decaying extra noise (Szita &amp;amp; Lörincz 2006)
&lt;/span&gt;        &lt;span class="n"&gt;noise_multiplier&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;iteration&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;std_decay_time&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;sample_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;th_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;square&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;extra_std&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;noise_multiplier&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Sample and evaluate
&lt;/span&gt;        &lt;span class="n"&gt;thetas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;th_mean&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;sample_std&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;th&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;th&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;thetas&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

        &lt;span class="c1"&gt;# Select elite and refit distribution
&lt;/span&gt;        &lt;span class="n"&gt;elite_inds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argsort&lt;/span&gt;&lt;span class="p"&gt;()[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;n_elite&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;
        &lt;span class="n"&gt;elite_thetas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;thetas&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;elite_inds&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;th_mean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;elite_thetas&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;th_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;elite_thetas&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;var&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# despite its name, th_std holds the variance; sample_std takes the square root&lt;/span&gt;

        &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Iter &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;iteration&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="n"&gt;d&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | Mean: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="mf"&gt;6.1&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; | Max: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;th_mean&lt;/span&gt;

&lt;span class="n"&gt;best_theta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;cem&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CartPole-v1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Iter   1 | Mean:   66.8 | Max: 500
# Iter  10 | Mean:  384.0 | Max: 500
# Iter  30 | Mean:  495.2 | Max: 500
# Iter  50 | Mean:  499.1 | Max: 500
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The population mean reward climbs from 67 to 499 in 50 iterations. With a batch mean of 499.1 against a cap of 500, almost every sample in the final batch scores at or near the maximum. Let's verify with 100 evaluation episodes:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;scores&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CartPole-v1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;best_theta&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Mean: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; ± &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;scores&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Mean: 500 ± 0
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Perfect score. Four parameters, zero gradients, 50 iterations.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;CEM works by maintaining a Gaussian distribution over policy parameters and repeatedly narrowing it toward the best-performing region. Each iteration has three steps:&lt;/p&gt;

&lt;h3&gt;
  
  
  Step 1: Sample
&lt;/h3&gt;

&lt;p&gt;We draw &lt;code&gt;batch_size=200&lt;/code&gt; parameter vectors from a Gaussian:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;thetas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;th_mean&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;sample_std&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Each &lt;code&gt;theta&lt;/code&gt; is a candidate policy. In iteration 1, the mean is all zeros and the fitted standard deviation is 1.0 (about 1.12 once the extra exploration noise is added), so we're effectively sampling random policies.&lt;/p&gt;
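
&lt;p&gt;The sampling spread also includes the decaying extra noise from the full listing. The sketch below mirrors the &lt;code&gt;sample_std&lt;/code&gt; line in &lt;code&gt;cem()&lt;/code&gt;; the helper name &lt;code&gt;sampling_std&lt;/code&gt; is hypothetical, and the fitted variance is held at 1.0 purely to isolate the effect of the extra noise (in a real run it shrinks as the elites cluster).&lt;/p&gt;

```python
import numpy as np

def sampling_std(variance, iteration, extra_std=0.5, std_decay_time=25):
    """Fitted variance plus extra variance that decays linearly to zero."""
    noise_multiplier = max(1.0 - iteration / float(std_decay_time), 0)
    return np.sqrt(variance + np.square(extra_std) * noise_multiplier)

initial_variance = np.ones(4)  # what cem() starts from when initial_std=1.0
for iteration in (0, 12, 25, 40):
    # Every dimension is identical here, so print only the first entry
    print(iteration, sampling_std(initial_variance, iteration)[0])
```

&lt;p&gt;With these defaults the spread starts at about 1.12 and falls back to exactly 1.0 from iteration 25 onwards, when the extra noise has fully decayed.&lt;/p&gt;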

&lt;h3&gt;
  
  
  Step 2: Evaluate and Select
&lt;/h3&gt;

&lt;p&gt;We run each candidate policy on CartPole and rank them by total reward. Then we keep only the top 20% (the "elite" set):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;rewards&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="nf"&gt;evaluate_policy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;env_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;th&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;th&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;thetas&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;elite_inds&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argsort&lt;/span&gt;&lt;span class="p"&gt;()[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;n_elite&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;  &lt;span class="c1"&gt;# Top 40 out of 200
&lt;/span&gt;&lt;span class="n"&gt;elite_thetas&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;thetas&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;elite_inds&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
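
&lt;p&gt;The &lt;code&gt;argsort&lt;/code&gt; slice is easier to follow on a tiny batch. The ten rewards below are made-up numbers for illustration only.&lt;/p&gt;

```python
import numpy as np

# Ten hypothetical episode returns, one per candidate policy
rewards = np.array([12.0, 500.0, 87.0, 33.0, 260.0, 9.0, 410.0, 150.0, 61.0, 20.0])
elite_frac = 0.2
n_elite = int(np.round(len(rewards) * elite_frac))  # 2 of 10

# argsort returns indices in ascending reward order, so the tail holds the best
elite_inds = rewards.argsort()[-n_elite:]
print(elite_inds)           # indices of the two highest rewards
print(rewards[elite_inds])  # their rewards, lowest of the elite first
```

&lt;p&gt;Here the elite set is candidates 6 and 1, with rewards 410 and 500. The order within the elite set does not matter, because the refit step only uses their mean and variance.&lt;/p&gt;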



&lt;h3&gt;
  
  
  Step 3: Refit the Distribution
&lt;/h3&gt;

&lt;p&gt;We refit the Gaussian to match the elite samples:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;th_mean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;elite_thetas&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;th_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;elite_thetas&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;var&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The new mean moves toward parameters that performed well. The new variance shrinks because the elite samples cluster together. Next iteration, we sample from this tighter distribution, generating better candidates on average.&lt;/p&gt;
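
&lt;p&gt;The narrowing is easy to watch on a toy problem. This sketch runs the same sample, select and refit loop on a hypothetical two-parameter reward with a single peak at &lt;code&gt;[2, -1]&lt;/code&gt;; it leaves out the extra decaying noise, and none of the names below come from the CartPole code.&lt;/p&gt;

```python
import numpy as np

def toy_reward(thetas):
    """Hypothetical smooth reward with one peak at theta = [2, -1]."""
    return -np.sum((thetas - np.array([2.0, -1.0])) ** 2, axis=-1)

rng = np.random.default_rng(0)
mean, var = np.zeros(2), np.ones(2)
for iteration in range(10):
    # Sample 200 candidates from the current Gaussian
    thetas = mean + np.sqrt(var) * rng.standard_normal((200, 2))
    # Keep the top 20% by reward
    elite = thetas[toy_reward(thetas).argsort()[-40:]]
    # Refit the Gaussian to the elite set
    mean, var = elite.mean(axis=0), elite.var(axis=0)
    print(f"iter {iteration + 1:2d} | mean = {np.round(mean, 3)} | var = {np.round(var, 4)}")
```

&lt;p&gt;The printed mean should drift toward the peak while the variance collapses, which is the behaviour the CartPole run relies on. Without the extra noise term, the variance can collapse before the mean has arrived on harder problems.&lt;/p&gt;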

&lt;h3&gt;
  
  
  Watching It Converge
&lt;/h3&gt;

&lt;p&gt;The training curve shows how the population improves:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mean_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;b-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Population mean&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;elite_mean_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;r-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Elite mean&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;g--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best in batch&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axhline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;500&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;k&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;:&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Max possible (500)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Iteration&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Total Reward&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjsrtq0yzp11jtcxexvvj.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fjsrtq0yzp11jtcxexvvj.webp" alt="CEM training curve on CartPole-v1 showing population mean climbing from 67 to 500 over 50 iterations" width="800" height="450"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The elite mean hits 500 almost immediately (iteration 2). But the population mean takes longer to catch up because the distribution is still wide. By iteration 30, even randomly sampled policies from the learned distribution score near-perfect.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Distribution Narrows Over Time
&lt;/h3&gt;

&lt;p&gt;To see this visually, here's how the reward distribution across the 200 samples evolves:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;iteration_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;title&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;axes&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;selected_iterations&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;titles&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;hist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iteration_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bins&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;steelblue&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;edgecolor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;white&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;iteration_rewards&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;red&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F7n9zgznzl5aqo9ix7wa5.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F7n9zgznzl5aqo9ix7wa5.webp" alt="Reward distributions at iterations 1, 10, and 50 showing the population concentrating at 500" width="800" height="261"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;In iteration 1, most policies fail quickly (reward &amp;lt; 100) with a few lucky ones reaching 500. By iteration 10, the distribution is bimodal: many policies near 500 but some still struggling. By iteration 50, the entire population clusters at 500. The distribution has collapsed onto the solution.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Noisy Cross-Entropy Method
&lt;/h3&gt;

&lt;p&gt;The original CEM (Rubinstein 1999) has a failure mode: the variance can collapse to zero too quickly, trapping the search in a local optimum. Szita and Lörincz (2006) fixed this with the "noisy" variant that adds decaying extra variance:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Csigma_%257Bt%252B1%257D%255E2%2520%253D%2520%255Csigma_%257Bt%252C%255Ctext%257Belite%257D%257D%255E2%2520%252B%2520Z_t%255E2%2520%255Ccdot%2520%255Csigma_%257B%255Ctext%257Bextra%257D%257D%255E2" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Csigma_%257Bt%252B1%257D%255E2%2520%253D%2520%255Csigma_%257Bt%252C%255Ctext%257Belite%257D%257D%255E2%2520%252B%2520Z_t%255E2%2520%255Ccdot%2520%255Csigma_%257B%255Ctext%257Bextra%257D%257D%255E2" alt="equation" width="265" height="31"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$Z_t = \max(1 - t / T_{\text{decay}},\; 0)$&lt;/code&gt; decays linearly to zero. Early iterations get extra exploration; later iterations trust the elite variance.&lt;/p&gt;

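&lt;p&gt;Our code follows the same schedule, with two details worth noting: &lt;code&gt;th_std&lt;/code&gt; holds the elite variance (not the standard deviation), and the multiplier scales the extra variance linearly rather than as &lt;code&gt;$Z_t^2$&lt;/code&gt;:&lt;br&gt;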
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;noise_multiplier&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;iteration&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;std_decay_time&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;sample_std&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;th_std&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;square&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;extra_std&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;noise_multiplier&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The extra noise (&lt;code&gt;extra_std=0.5&lt;/code&gt;) decays linearly over &lt;code&gt;std_decay_time=25&lt;/code&gt; iterations. From iteration 25 onward the multiplier is zero, so the sampling distribution uses only the elite variance.&lt;/p&gt;
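
&lt;p&gt;To make the schedule concrete, here is a minimal sketch that tabulates the sampling standard deviation at a few iterations. The helper name &lt;code&gt;sampling_std&lt;/code&gt; is hypothetical, and the elite variance of 0.04 is an assumed placeholder rather than a value measured from the run above.&lt;/p&gt;

```python
import numpy as np

def sampling_std(elite_var, iteration, extra_std=0.5, std_decay_time=25):
    """Sampling std for one iteration, mirroring the two lines above."""
    # Linear decay from 1 at iteration 0 to 0 at std_decay_time, then clipped at 0.
    noise_multiplier = max(1.0 - iteration / float(std_decay_time), 0)
    return np.sqrt(elite_var + np.square(extra_std) * noise_multiplier)

# Assumed elite variance, held fixed so that only the noise term changes.
elite_var = 0.04
for iteration in (0, 10, 25, 40):
    print(iteration, round(float(sampling_std(elite_var, iteration)), 3))
```

&lt;p&gt;With these assumed numbers the sampling std starts at about 0.54 and settles at 0.2 once the extra noise has decayed away.&lt;/p&gt;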

&lt;h3&gt;
  
  
  Hyperparameters
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Parameter&lt;/th&gt;
&lt;th&gt;Value&lt;/th&gt;
&lt;th&gt;Effect&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;batch_size&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;200&lt;/td&gt;
&lt;td&gt;More samples = better coverage but slower per iteration&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;elite_frac&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;0.2&lt;/td&gt;
&lt;td&gt;Lower = more selective, faster convergence, risk of premature collapse&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;initial_std&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;1.0&lt;/td&gt;
&lt;td&gt;Too low = miss good regions; too high = waste samples on extreme policies&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;extra_std&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;0.5&lt;/td&gt;
&lt;td&gt;Noise injection; 0 = original CEM, &amp;gt;0 = noisy CEM&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;std_decay_time&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;25&lt;/td&gt;
&lt;td&gt;How many iterations before extra noise disappears&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;The most sensitive parameter is &lt;code&gt;elite_frac&lt;/code&gt;. At 0.2 (keep the top 40 of 200), we balance exploitation and exploration. Setting it to 0.01 (keep the top 2) would converge faster in easy environments but risks premature collapse in hard ones, since two samples give a very poor variance estimate.&lt;/p&gt;
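
&lt;p&gt;As a rough sketch of what the elite fraction means in practice, the snippet below picks the top-scoring parameter vectors from a batch. The helper &lt;code&gt;select_elites&lt;/code&gt; is a hypothetical name, and the rewards are random placeholders rather than real CartPole returns.&lt;/p&gt;

```python
import numpy as np

def select_elites(thetas, rewards, elite_frac):
    """Return the top elite_frac of parameter vectors, ranked by reward."""
    n_elite = max(1, int(round(len(rewards) * elite_frac)))
    # argsort is ascending, so the last n_elite indices are the best performers.
    elite_idx = np.argsort(rewards)[-n_elite:]
    return thetas[elite_idx]

rng = np.random.default_rng(0)
thetas = rng.normal(size=(200, 4))       # 200 candidate 4-parameter policies
rewards = rng.uniform(0, 500, size=200)  # placeholder rewards, not real episodes

for elite_frac in (0.01, 0.2, 0.5):
    print(elite_frac, select_elites(thetas, rewards, elite_frac).shape[0])
```

&lt;p&gt;The elite count is simply the batch size times the fraction, so the mean and variance at 0.01 are estimated from only two samples.&lt;/p&gt;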

&lt;h3&gt;
  
  
  CEM vs Random Search
&lt;/h3&gt;

&lt;p&gt;Both CEM and random search sample 200 policies per iteration. The difference: random search starts fresh every time, while CEM builds on what worked:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cem_mean_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;b-&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CEM (population mean)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;iters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_mean_rewards&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;r--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linewidth&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Random search (mean)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkw1tl8s1h1urkxx3ka5d.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fkw1tl8s1h1urkxx3ka5d.webp" alt="CEM population mean climbing to 500 while random search stays flat at ~60" width="800" height="450"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Random search averages about 60 reward per iteration, forever. CEM reaches 500 because each iteration's distribution is informed by the last. The "select and refit" loop creates a directed search through parameter space.&lt;/p&gt;
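
&lt;p&gt;The same contrast can be reproduced without an RL environment. The sketch below is an illustration under stated assumptions, not the notebook's code: &lt;code&gt;toy_reward&lt;/code&gt; is a made-up quadratic objective standing in for episode return, and &lt;code&gt;run_search&lt;/code&gt; is a hypothetical helper that runs either plain CEM (no extra noise) or random search from the same starting distribution.&lt;/p&gt;

```python
import numpy as np

def toy_reward(theta):
    # Stand-in objective with its maximum (0) at theta = [1, -1, 0.5, 0.5].
    target = np.array([1.0, -1.0, 0.5, 0.5])
    return -np.sum((theta - target) ** 2)

def run_search(refit, n_iters=30, batch_size=200, elite_frac=0.2, seed=0):
    """Return the mean batch reward per iteration; refit=False is random search."""
    rng = np.random.default_rng(seed)
    mean, std = np.zeros(4), np.ones(4)
    n_elite = int(round(batch_size * elite_frac))
    history = []
    for _ in range(n_iters):
        thetas = rng.normal(mean, std, size=(batch_size, 4))
        rewards = np.array([toy_reward(th) for th in thetas])
        history.append(rewards.mean())
        if refit:
            # CEM step: refit the Gaussian to the top-scoring samples.
            elites = thetas[np.argsort(rewards)[-n_elite:]]
            mean, std = elites.mean(axis=0), elites.std(axis=0)
    return history

cem_history = run_search(refit=True)
random_history = run_search(refit=False)
print("CEM first/last:", round(cem_history[0], 2), round(cem_history[-1], 2))
print("Random first/last:", round(random_history[0], 2), round(random_history[-1], 2))
```

&lt;p&gt;Both runs share a seed, so their first batches are identical. On this toy problem the random-search average should stay roughly flat, while the CEM average moves toward the maximum of 0.&lt;/p&gt;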

&lt;h3&gt;
  
  
  CEM vs Policy Gradients vs DQN
&lt;/h3&gt;

&lt;p&gt;How does CEM compare to the gradient-based methods we've covered?&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Method&lt;/th&gt;
&lt;th&gt;What it optimises&lt;/th&gt;
&lt;th&gt;Needs gradients?&lt;/th&gt;
&lt;th&gt;Scales to large nets?&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;CEM&lt;/td&gt;
&lt;td&gt;Policy parameters directly&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;td&gt;Poorly (&amp;gt;1000 params)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;REINFORCE&lt;/a&gt;&lt;/td&gt;
&lt;td&gt;Policy parameters via log-prob gradient&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;DQN&lt;/a&gt;&lt;/td&gt;
&lt;td&gt;Value function (Q-values)&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;a href="https://sesen.ai/blog/q-learning-frozen-lake-from-scratch" rel="noopener noreferrer"&gt;Q-Learning&lt;/a&gt;&lt;/td&gt;
&lt;td&gt;Value function (Q-table)&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;td&gt;No (tabular only)&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;CEM's sweet spot: &lt;strong&gt;problems with fewer than ~1000 parameters&lt;/strong&gt; where you want a simple, parallelisable algorithm. For a 4-parameter linear policy on CartPole, CEM is hard to beat. For a million-parameter Atari network, you need &lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;policy gradients&lt;/a&gt; or DQN.&lt;/p&gt;

&lt;h3&gt;
  
  
  When NOT to Use CEM
&lt;/h3&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;High-dimensional parameter spaces.&lt;/strong&gt; A fixed-size batch covers the search space far more sparsely as the number of parameters grows, so a 1000-parameter network needs enormous batch sizes&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Environments with sparse rewards.&lt;/strong&gt; If most policies score zero (e.g., Montezuma's Revenge), the elite set is just noise&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;When you need sample efficiency.&lt;/strong&gt; CEM uses 200 episodes per iteration vs REINFORCE using ~5 episodes per batch. If environment evaluations are expensive, gradient methods win&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Continuous action spaces with complex dynamics.&lt;/strong&gt; CEM with a linear policy can only learn linear decision boundaries. Problems requiring nonlinear policies need either a neural network (large parameter space) or a different algorithm&lt;/li&gt;
&lt;/ol&gt;

&lt;h3&gt;
  
  
  Connection to Genetic Algorithms
&lt;/h3&gt;

&lt;p&gt;If you read the &lt;a href="https://sesen.ai/blog/genetic-algorithms-from-scratch" rel="noopener noreferrer"&gt;genetic algorithms post&lt;/a&gt;, CEM will feel familiar. Both are population-based, derivative-free optimisation methods. The difference is in how they generate the next population:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Genetic algorithms&lt;/strong&gt; use crossover and mutation operators on individual solutions&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;CEM&lt;/strong&gt; fits a probability distribution to the elite set and samples from it&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;CEM is sometimes called an "estimation of distribution algorithm" (EDA). Instead of recombining individual solutions, it models the structure of good solutions as a distribution and samples new candidates from that model. For real-valued parameter optimisation, this Gaussian model is often more effective than genetic crossover.&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;p&gt;The Cross-Entropy Method was introduced by &lt;strong&gt;Reuven Rubinstein&lt;/strong&gt; in his 1999 paper &lt;a href="https://doi.org/10.1023/A:1010091220143" rel="noopener noreferrer"&gt;"The Cross-Entropy Method for Combinatorial and Continuous Optimization"&lt;/a&gt;. The name comes from the original application: minimising the cross-entropy (KL divergence) between a reference distribution and the optimal importance sampling distribution for rare-event simulation.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Core Idea
&lt;/h3&gt;

&lt;p&gt;Rubinstein's insight was that rare-event estimation and optimisation are essentially the same problem. To estimate &lt;code&gt;$P(S(X) \geq \gamma)$&lt;/code&gt; for a rare threshold &lt;code&gt;$\gamma$&lt;/code&gt;, you need to find a sampling distribution that concentrates on high-&lt;code&gt;$S(X)$&lt;/code&gt; regions. The CE method does this by iteratively:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Drawing samples from the current distribution &lt;code&gt;$f(\cdot;\, v_t)$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;Selecting the elite samples (those with &lt;code&gt;$S(X) \geq \gamma_t$&lt;/code&gt;)&lt;/li&gt;
&lt;li&gt;Updating the distribution parameters to minimise the KL divergence to the empirical elite distribution&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;For a Gaussian family, step 3 has a closed-form solution: the mean and variance of the elite samples. This is exactly what our implementation does.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Formal Algorithm
&lt;/h3&gt;

&lt;p&gt;From Rubinstein and Kroese (2004), the CEM update for a parametric family &lt;code&gt;$\{f(\cdot;\, v)\}$&lt;/code&gt; is:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dv_%257Bt%252B1%257D%2520%253D%2520%255Carg%255Cmax_v%2520%255Cfrac%257B1%257D%257BN%257D%2520%255Csum_%257Bi%253D1%257D%255E%257BN%257D%2520I%255C%257BS%28X_i%29%2520%255Cgeq%2520%255Cgamma_t%255C%257D%2520%255Cln%2520f%28X_i%253B%255C%252C%2520v%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dv_%257Bt%252B1%257D%2520%253D%2520%255Carg%255Cmax_v%2520%255Cfrac%257B1%257D%257BN%257D%2520%255Csum_%257Bi%253D1%257D%255E%257BN%257D%2520I%255C%257BS%28X_i%29%2520%255Cgeq%2520%255Cgamma_t%255C%257D%2520%255Cln%2520f%28X_i%253B%255C%252C%2520v%29" alt="equation" width="504" height="72"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;where &lt;code&gt;$I\{\cdot\}$&lt;/code&gt; is the indicator function selecting elite samples. For a multivariate Gaussian with diagonal covariance, this yields:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu_%257Bt%252B1%257D%2520%253D%2520%255Cfrac%257B1%257D%257BN_%257B%255Ctext%257Belite%257D%257D%257D%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Belite%257D%257D%2520X_i%252C%2520%255Cquad%2520%255Csigma_%257Bt%252B1%257D%255E2%2520%253D%2520%255Cfrac%257B1%257D%257BN_%257B%255Ctext%257Belite%257D%257D%257D%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Belite%257D%257D%2520%28X_i%2520-%2520%255Cmu_%257Bt%252B1%257D%29%255E2" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Cmu_%257Bt%252B1%257D%2520%253D%2520%255Cfrac%257B1%257D%257BN_%257B%255Ctext%257Belite%257D%257D%257D%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Belite%257D%257D%2520X_i%252C%2520%255Cquad%2520%255Csigma_%257Bt%252B1%257D%255E2%2520%253D%2520%255Cfrac%257B1%257D%257BN_%257B%255Ctext%257Belite%257D%257D%257D%2520%255Csum_%257Bi%2520%255Cin%2520%255Ctext%257Belite%257D%257D%2520%28X_i%2520-%2520%255Cmu_%257Bt%252B1%257D%29%255E2" alt="equation" width="573" height="64"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The sample mean and variance of the elite set. Elegantly simple.&lt;/p&gt;
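
&lt;p&gt;A quick numerical check can make this closed form more tangible. The sketch below compares the average Gaussian log-likelihood of a sample set at the elite mean and variance against two perturbed alternatives. The data are synthetic stand-ins for an elite set, and &lt;code&gt;avg_log_likelihood&lt;/code&gt; is a hypothetical helper written for this illustration.&lt;/p&gt;

```python
import numpy as np

def avg_log_likelihood(x, mu, var):
    """Average log-density of the rows of x under a diagonal Gaussian N(mu, var)."""
    per_dim = -0.5 * np.log(2 * np.pi * var) - 0.5 * (x - mu) ** 2 / var
    return np.mean(np.sum(per_dim, axis=1))

rng = np.random.default_rng(1)
# Synthetic "elite set": 40 samples in 2 dimensions.
elites = rng.normal(loc=[0.5, -1.0], scale=[0.3, 0.8], size=(40, 2))

# Closed-form update: per-dimension mean and (1/N) variance of the elites.
mu_hat = elites.mean(axis=0)
var_hat = elites.var(axis=0)

best = avg_log_likelihood(elites, mu_hat, var_hat)
shifted = avg_log_likelihood(elites, mu_hat + 0.1, var_hat)
widened = avg_log_likelihood(elites, mu_hat, var_hat * 1.5)
print(best, shifted, widened)
```

&lt;p&gt;The first value is the largest of the three. Note that NumPy's default &lt;code&gt;var&lt;/code&gt; divides by N, which matches the normaliser in the equation above.&lt;/p&gt;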

&lt;h3&gt;
  
  
  From Rare Events to Tetris
&lt;/h3&gt;

&lt;p&gt;The method found its way into reinforcement learning through &lt;strong&gt;Szita and Lörincz (2006)&lt;/strong&gt;, &lt;a href="https://doi.org/10.1162/neco.2006.18.12.2936" rel="noopener noreferrer"&gt;"Learning Tetris Using the Noisy Cross-Entropy Method"&lt;/a&gt;. They made two key modifications for RL:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Noisy updates&lt;/strong&gt;: Adding decaying extra variance to prevent premature convergence (the &lt;code&gt;extra_std&lt;/code&gt; parameter in our code)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Direct policy search&lt;/strong&gt;: Treating the policy's weight vector as the parameter to optimise, with episode return as the objective function&lt;/li&gt;
&lt;/ol&gt;

&lt;blockquote&gt;
&lt;p&gt;"The noisy cross-entropy method adds a time-decreasing noise term to avoid premature convergence of the variance to zero."&lt;br&gt;
&lt;em&gt;Szita and Lörincz (2006)&lt;/em&gt;&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Their noisy CEM achieved record-breaking performance on Tetris at the time, outperforming methods that required orders of magnitude more computation. Our implementation follows their variant faithfully, including the linear noise decay schedule described in Section 3 of their paper.&lt;/p&gt;

&lt;h3&gt;
  
  
  Theoretical Properties
&lt;/h3&gt;

&lt;p&gt;Unlike policy gradient methods, CEM as used here offers no guarantee of converging to even a local optimum. It is a heuristic. However, it has practical advantages:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Embarrassingly parallel&lt;/strong&gt;: All 200 evaluations per iteration are independent&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;No reward shaping needed&lt;/strong&gt;: Works with any scalar objective, even non-differentiable ones&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Robust to noisy evaluations&lt;/strong&gt;: The elite selection acts as a natural filter&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The method's simplicity is also its limitation. As Rubinstein and Kroese note, the Gaussian parametric family assumes the optimal parameter region is unimodal. Multi-modal reward landscapes can trap CEM in a single mode.&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1023/A:1010091220143" rel="noopener noreferrer"&gt;Rubinstein (1999)&lt;/a&gt;, "The Cross-Entropy Method for Combinatorial and Continuous Optimization" - The original CE method paper&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1007/978-1-4757-4321-0" rel="noopener noreferrer"&gt;Rubinstein and Kroese (2004)&lt;/a&gt;, &lt;em&gt;The Cross-Entropy Method: A Unified Approach to Combinatorial Optimization, Monte Carlo Simulation, and Machine Learning&lt;/em&gt; - The comprehensive textbook&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1162/neco.2006.18.12.2936" rel="noopener noreferrer"&gt;Szita and Lörincz (2006)&lt;/a&gt;, "Learning Tetris Using the Noisy Cross-Entropy Method" - The noisy variant for RL&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://arxiv.org/abs/1703.03864" rel="noopener noreferrer"&gt;Salimans et al. (2017)&lt;/a&gt;, "Evolution Strategies as a Scalable Alternative to Reinforcement Learning" - Modern evolution strategies at scale (OpenAI)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Sutton and Barto (2018)&lt;/strong&gt;, Ch. 13 - Policy gradient methods for comparison&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/reinforcement-learning/cross_entropy_method.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Elite fraction sweep&lt;/strong&gt;: Try &lt;code&gt;elite_frac&lt;/code&gt; values of 0.01, 0.1, 0.2, and 0.5. How does selectivity affect convergence speed and stability?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Noisy vs vanilla CEM&lt;/strong&gt;: Set &lt;code&gt;extra_std=0&lt;/code&gt; and compare convergence. Does the noisy variant help on CartPole, or only on harder problems?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Neural network policy&lt;/strong&gt;: Replace the linear policy with a small neural net (8 hidden units). How many CEM iterations does it take to solve CartPole now? At what network size does CEM become impractical?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Different environments&lt;/strong&gt;: Try CEM on &lt;code&gt;Acrobot-v1&lt;/code&gt; or &lt;code&gt;MountainCar-v0&lt;/code&gt;. Which environments does CEM handle well, and which expose its limitations?&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/q-learning-visualizer" rel="noopener noreferrer"&gt;Q-Learning Visualiser&lt;/a&gt; - See value-based RL in action and compare it with the policy search approach of CEM&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/genetic-algorithms-from-scratch" rel="noopener noreferrer"&gt;Genetic Algorithms: From Line Fitting to the Travelling Salesman&lt;/a&gt; - Another population-based, derivative-free optimisation method. CEM replaces crossover and mutation with distribution fitting.&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/policy-gradients-reinforce-from-scratch" rel="noopener noreferrer"&gt;Policy Gradients: REINFORCE from Scratch with NumPy&lt;/a&gt; - The gradient-based alternative for policy search. Uses backpropagation through the policy, which scales to large networks but requires a differentiable policy (the reward itself need not be differentiable).&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/deep-q-networks-experience-replay-target-networks" rel="noopener noreferrer"&gt;Deep Q-Networks: Experience Replay and Target Networks&lt;/a&gt; - Value-based RL with neural networks. A fundamentally different approach that learns what states are valuable rather than directly searching for good policies.&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Why is the Cross-Entropy Method called "cross-entropy" if it does not use a loss function?
&lt;/h3&gt;

&lt;p&gt;The name comes from the original application in rare-event simulation, where the algorithm minimises the cross-entropy (KL divergence) between the current sampling distribution and the optimal importance sampling distribution. In the reinforcement learning context, the name persists even though the update reduces to simply computing the mean and variance of the elite samples.&lt;/p&gt;

&lt;h3&gt;
  
  
  How does CEM compare to random search?
&lt;/h3&gt;

&lt;p&gt;Both methods sample candidate policies each iteration, but random search draws from a fixed distribution every time, while CEM updates its distribution based on the best-performing candidates. This directed search means CEM builds on previous successes, converging to good solutions far faster than random search on problems with structure to exploit.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can CEM solve problems with continuous action spaces?
&lt;/h3&gt;

&lt;p&gt;CEM can optimise over continuous policy parameters, but the policy itself determines how actions are generated. A linear policy whose output is thresholded (or passed through an argmax) can only produce binary or discrete decisions. For truly continuous action spaces with complex dynamics, you would need a more expressive policy architecture, which increases the parameter count and makes CEM less practical.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is the role of the elite fraction hyperparameter?
&lt;/h3&gt;

&lt;p&gt;The elite fraction controls how selective the algorithm is when choosing which candidates inform the next distribution. A smaller fraction (e.g. 0.01) converges faster but risks collapsing onto a local optimum. A larger fraction (e.g. 0.5) explores more broadly but converges more slowly. A value around 0.2 is a common default that balances exploitation and exploration.&lt;/p&gt;

&lt;h3&gt;
  
  
  Why does the noisy CEM variant add extra variance that decays over time?
&lt;/h3&gt;

&lt;p&gt;Without extra variance, the sampling distribution can collapse to near-zero variance too quickly, trapping the search around a potentially suboptimal solution. The decaying noise keeps exploration alive in early iterations when the algorithm is still uncertain about the best region, then gradually disappears to allow precise convergence in later iterations.&lt;/p&gt;
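
&lt;p&gt;A minimal sketch of that idea follows. The linear decay schedule and the names &lt;code&gt;extra_noise&lt;/code&gt; and &lt;code&gt;noisy_cem_update&lt;/code&gt; are hypothetical, chosen only to illustrate the mechanism; other decay schedules would work the same way.&lt;/p&gt;

```python
import numpy as np

def extra_noise(iteration, initial_noise=1.0, decay_iters=10):
    # Hypothetical linear schedule: starts at initial_noise, reaches zero after decay_iters
    return max(initial_noise * (1.0 - iteration / decay_iters), 0.0)

def noisy_cem_update(elite, iteration):
    # Standard CEM refit, plus extra variance that fades out over time
    mean = elite.mean(axis=0)
    var = elite.var(axis=0) + extra_noise(iteration)
    return mean, np.sqrt(var)

# Elite samples that have already collapsed onto nearly the same point
rng = np.random.default_rng(0)
elite = 2.0 + 1e-3 * rng.standard_normal((10, 3))

for it in (0, 5, 10):
    _, std = noisy_cem_update(elite, it)
    print(f"iteration {it:2d}: sampling std = {std.round(3)}")
```

&lt;p&gt;Even though the elite samples are almost identical, the early iterations still sample with a wide standard deviation. Once the schedule reaches zero, the update is the plain CEM refit again.&lt;/p&gt;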

</description>
      <category>reinforcementlearning</category>
      <category>optimisation</category>
    </item>
    <item>
      <title>PCR vs PLS: When Fewer Features Beat More</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Sun, 19 Apr 2026 15:38:56 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/pcr-vs-pls-when-fewer-features-beat-more-2plp</link>
      <guid>https://dev.to/berkan_sesen/pcr-vs-pls-when-fewer-features-beat-more-2plp</guid>
      <description>&lt;p&gt;How much should a baseball team pay its players? The 1986 Major League season gives us 263 hitters with 19 statistics each: at-bats, hits, home runs, years played, and more. Predicting salary from performance sounds like a textbook regression problem, but 19 correlated features make it anything but. Throw them all into a &lt;a href="https://sesen.ai/blog/linear-regression-five-ways" rel="noopener noreferrer"&gt;linear regression&lt;/a&gt; and the model fits the training data beautifully but falls apart on held-out players. The coefficient estimates are wildly unstable, and salary predictions swing by thousands on minor input changes.&lt;/p&gt;

&lt;p&gt;The fix is not a fancier model. It is &lt;em&gt;fewer features&lt;/em&gt;, chosen more carefully. This post covers two classic strategies for doing exactly that: Principal Component Regression (PCR) and Partial Least Squares (PLS).&lt;/p&gt;

&lt;p&gt;By the end, you'll understand how both methods compress correlated features into a handful of components, why PLS typically needs fewer components than PCR, and when each approach is the right tool.&lt;/p&gt;

&lt;h2&gt;
  
  
  Quick Win: Predict Salaries with 6 Features Instead of 19
&lt;/h2&gt;

&lt;p&gt;Click the badge to open the interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/supervised/pcr_vs_pls.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;We'll use the classic ISLR Hitters dataset: 263 baseball players described by 19 features (at-bats, hits, home runs, years played, etc.), with salary in thousands of dollars as the target.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pandas&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.preprocessing&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;scale&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.decomposition&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;PCA&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.linear_model&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;LinearRegression&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.cross_decomposition&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;PLSRegression&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.model_selection&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;KFold&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cross_val_score&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;train_test_split&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.metrics&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;mean_squared_error&lt;/span&gt;

&lt;span class="c1"&gt;# Load and prepare the Hitters dataset
&lt;/span&gt;&lt;span class="n"&gt;df&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;read_csv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;https://raw.githubusercontent.com/selva86/datasets/master/Hitters.csv&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;dropna&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;dummies&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_dummies&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;League&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Division&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;NewLeague&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]])&lt;/span&gt;
&lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Salary&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;concat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="n"&gt;df&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;drop&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Salary&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;League&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Division&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;NewLeague&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;astype&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;float64&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;dummies&lt;/span&gt;&lt;span class="p"&gt;[[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;League_N&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Division_W&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;NewLeague_N&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt;
&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Train/test split
&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;train_test_split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;test_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# PCR: PCA on scaled training data, then regression
&lt;/span&gt;&lt;span class="n"&gt;pca&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PCA&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;X_train_pc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pca&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit_transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;X_test_pc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pca&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="c1"&gt;# Use 10-fold CV to find the best number of components
&lt;/span&gt;&lt;span class="n"&gt;kf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;KFold&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_splits&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;regr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;LinearRegression&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;mse_by_ncomp&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="nf"&gt;cross_val_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;regr&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X_train_pc&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to_numpy&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
        &lt;span class="n"&gt;cv&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;kf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scoring&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;neg_mean_squared_error&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
    &lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;mse_by_ncomp&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;score&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;best_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mse_by_ncomp&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Best PCR components: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;CV MSE: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;mse_by_ncomp&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Evaluate on test set
&lt;/span&gt;&lt;span class="n"&gt;regr&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train_pc&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;pcr_test_mse&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;mean_squared_error&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;regr&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test_pc&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt;&lt;span class="p"&gt;]))&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;PCR test MSE (&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; components): &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;pcr_test_mse&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Compare to full OLS: all 19 PCs span the same space as the 19 scaled features
&lt;/span&gt;&lt;span class="n"&gt;regr_full&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;LinearRegression&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;regr_full&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train_pc&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ols_test_mse&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;mean_squared_error&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;regr_full&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test_pc&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Full OLS test MSE (19 features): &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;ols_test_mse&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;The result:&lt;/strong&gt; PCR with just 6 components achieves a test MSE of ~112,000, beating full OLS (test MSE ~117,000) using all 19 features. Fewer features, better predictions.&lt;/p&gt;
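
&lt;p&gt;One caveat about the snippet above: &lt;code&gt;scale&lt;/code&gt; standardises the training and test sets separately, and PCA is fitted once on the whole training set before cross-validation. A stricter variant learns the scaling and the PCA rotation inside each fold. The sketch below shows one way to do that with a scikit-learn &lt;code&gt;Pipeline&lt;/code&gt;. It is an alternative under stated assumptions, not the notebook's method, and it uses synthetic data in place of Hitters so that it runs offline. Its numbers are therefore not comparable to the ones reported here.&lt;/p&gt;

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV, KFold, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for Hitters (assumption): 263 rows, 19 correlated features
rng = np.random.default_rng(1)
latent = rng.standard_normal((263, 4))
X = latent @ rng.standard_normal((4, 19)) + 0.3 * rng.standard_normal((263, 19))
y = latent @ np.array([3.0, -2.0, 1.0, 0.5]) + rng.standard_normal(263)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=1)

# Scaling and PCA are refitted on the training part of every CV fold
pcr = Pipeline([
    ("scale", StandardScaler()),
    ("pca", PCA()),
    ("ols", LinearRegression()),
])

search = GridSearchCV(
    pcr,
    param_grid={"pca__n_components": list(range(1, X.shape[1] + 1))},
    cv=KFold(n_splits=10, shuffle=True, random_state=1),
    scoring="neg_mean_squared_error",
)
search.fit(X_train, y_train)

best_k = search.best_params_["pca__n_components"]
# predict() applies the training-set scaler and PCA to the test set
test_mse = mean_squared_error(y_test, search.predict(X_test))
print(f"Best PCR components: {best_k}")
print(f"PCR test MSE: {test_mse:.3f}")
```

&lt;p&gt;The design choice is that the test set never influences the preprocessing: its features are transformed with the mean, standard deviation, and principal axes learned from the training data only.&lt;/p&gt;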

&lt;h3&gt;
  
  
  PCR vs PLS: The Key Difference
&lt;/h3&gt;

&lt;p&gt;Now let's try PLS, which uses the target variable during dimension reduction:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# PLS: find the best number of components via CV
&lt;/span&gt;&lt;span class="n"&gt;pls_mse&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;pls&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PLSRegression&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_components&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="nf"&gt;cross_val_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;pls&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to_numpy&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
        &lt;span class="n"&gt;cv&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;kf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;scoring&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;neg_mean_squared_error&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
    &lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;pls_mse&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;score&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;best_pls_k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmin&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;pls_mse&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
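&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Best PLS components: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_pls_k&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Refit with 2 components, the count reported in this post
&lt;/span&gt;&lt;span class="n"&gt;pls_best&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PLSRegression&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_components&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;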
&lt;span class="n"&gt;pls_best&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;pls_test_mse&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;mean_squared_error&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_test&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;pls_best&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;PLS test MSE (2 components): &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;pls_test_mse&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;,.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;PLS with just 2 components&lt;/strong&gt; achieves a test MSE of ~105,000, beating both PCR and OLS. That is the power of supervised dimension reduction: PLS finds the directions that matter for the target, not just the directions of maximum variance.&lt;/p&gt;

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;Both methods solve the same problem: your 19 features are correlated (career stats like CAtBat, CHits, CRuns all move together), so fitting a separate coefficient for each one leads to noisy, unstable estimates. The solution is to compress correlated features into a smaller set of &lt;strong&gt;components&lt;/strong&gt; before regressing.&lt;/p&gt;

&lt;p&gt;The difference is &lt;em&gt;how&lt;/em&gt; they choose those components.&lt;/p&gt;

&lt;h3&gt;
  
  
  PCR: Unsupervised, Then Regress
&lt;/h3&gt;

&lt;p&gt;PCR works in two steps:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;PCA&lt;/strong&gt; finds the directions of maximum variance in &lt;code&gt;$X$&lt;/code&gt;, ignoring &lt;code&gt;$y$&lt;/code&gt; entirely&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Linear regression&lt;/strong&gt; fits &lt;code&gt;$y$&lt;/code&gt; on the top &lt;code&gt;$k$&lt;/code&gt; principal components
&lt;/li&gt;
&lt;/ol&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Step 1: PCA finds directions of maximum variance
&lt;/span&gt;&lt;span class="n"&gt;pca&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PCA&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;X_train_pc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pca&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit_transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;  &lt;span class="c1"&gt;# 19 features → 19 PCs
&lt;/span&gt;
&lt;span class="c1"&gt;# Step 2: Regress salary on just the first k PCs
&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;
&lt;span class="n"&gt;regr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;LinearRegression&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;regr&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train_pc&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The first principal component captures the direction along which the features vary the most. In our Hitters data, PC1 captures 39.9% of the total variance, and by PC7 we're at 93.4%.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F39iav845wbfyqmfdq7yq.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F39iav845wbfyqmfdq7yq.webp" alt="Explained variance per principal component (bars) and cumulative variance (line). The first 7 components capture over 93% of the total variance in the 19 features." width="800" height="417"&gt;&lt;/a&gt;&lt;/p&gt;
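
&lt;p&gt;Cumulative figures like these come from a running sum of &lt;code&gt;explained_variance_ratio_&lt;/code&gt;. The sketch below shows the calculation on synthetic correlated features (an assumption, standing in for the Hitters predictors), so the percentages it prints will differ from the ones quoted above.&lt;/p&gt;

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale

# Synthetic correlated features standing in for the Hitters predictors (assumption)
rng = np.random.default_rng(0)
latent = rng.standard_normal((200, 3))
X = latent @ rng.standard_normal((3, 19)) + 0.5 * rng.standard_normal((200, 19))

pca = PCA().fit(scale(X))
# Share of total variance captured by the first k components, for k = 1..19
cumulative = np.cumsum(pca.explained_variance_ratio_)

for k in (1, 3, 7):
    print(f"First {k} PCs explain {cumulative[k - 1]:.1%} of the variance")
```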

&lt;p&gt;But here's the catch: the directions of maximum variance in &lt;code&gt;$X$&lt;/code&gt; are not necessarily the directions most useful for predicting &lt;code&gt;$y$&lt;/code&gt;. PC1 might capture the spread between high-career and low-career players, but if salary depends more on a subtle interaction between recent performance and league, that signal could be buried in PC8 or PC12.&lt;/p&gt;
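
&lt;p&gt;A small synthetic example makes the catch visible. The data here is an assumption built for illustration: two strongly correlated features whose &lt;em&gt;difference&lt;/em&gt; drives the target. After scaling, PC1 is their shared movement and PC2 is the low-variance difference.&lt;/p&gt;

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import scale

rng = np.random.default_rng(0)
n = 2000
shared = rng.standard_normal(n)      # high-variance common component
gap = 0.5 * rng.standard_normal(n)   # low-variance difference between the features
X = np.column_stack([shared + gap, shared - gap])
# Assumption: the target depends only on the low-variance difference
y = 2.0 * gap + 0.1 * rng.standard_normal(n)

pca = PCA()
pcs = pca.fit_transform(scale(X))

r2_one = LinearRegression().fit(pcs[:, :1], y).score(pcs[:, :1], y)
r2_two = LinearRegression().fit(pcs[:, :2], y).score(pcs[:, :2], y)
print(f"Variance explained by PC1: {pca.explained_variance_ratio_[0]:.1%}")
print(f"R^2 with PC1 only:         {r2_one:.3f}")
print(f"R^2 with PC1 and PC2:      {r2_two:.3f}")
```

&lt;p&gt;PC1 holds most of the variance in &lt;code&gt;$X$&lt;/code&gt; yet says almost nothing about &lt;code&gt;$y$&lt;/code&gt;. The predictive signal only appears once the lower-variance component is included.&lt;/p&gt;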

&lt;h3&gt;
  
  
  PLS: Supervised from the Start
&lt;/h3&gt;

&lt;p&gt;PLS finds directions that simultaneously:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Explain variance in &lt;code&gt;$X$&lt;/code&gt; (like PCA)&lt;/li&gt;
&lt;li&gt;Correlate with &lt;code&gt;$y$&lt;/code&gt; (unlike PCA)
&lt;/li&gt;
&lt;/ul&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# PLS finds directions that maximise covariance between X and y
&lt;/span&gt;&lt;span class="n"&gt;pls&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;PLSRegression&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_components&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;pls&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_train&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;y_train&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;predictions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pls&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;scale&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_test&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is why PLS needs only 2 components where PCR needs 6. PLS searches directly for the directions that predict salary, while PCR has to hope that the high-variance directions in &lt;code&gt;$X$&lt;/code&gt; also happen to predict &lt;code&gt;$y$&lt;/code&gt;.&lt;/p&gt;

&lt;h3&gt;
  
  
  Choosing the Number of Components
&lt;/h3&gt;

&lt;p&gt;For both methods, we select the number of components with 10-fold cross-validation, keeping the value that minimises the CV MSE:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;

&lt;span class="n"&gt;fig&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplots&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;12&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="c1"&gt;# PCR
&lt;/span&gt;&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;mse_by_ncomp&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;-o&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;steelblue&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;markersize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;red&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_k&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Number of Components&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;10-Fold CV MSE&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;PCR&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;ax1&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# PLS
&lt;/span&gt;&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;pls_mse&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;-s&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;darkorange&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;markersize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;green&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Selected: 2 (parsimonious)&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axvline&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;best_pls_k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;red&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;linestyle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;:&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;label&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;CV min: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;best_pls_k&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_xlabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Number of Components&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_ylabel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;10-Fold CV MSE&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_title&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;PLS&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;legend&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;fontsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;9&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ax2&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tight_layout&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;show&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F2yslidjyargov31umd5p.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F2yslidjyargov31umd5p.webp" alt="Cross-validation MSE curves for PCR (left, minimum at 6 components) and PLS (right, CV minimum at 11 but 2 selected for parsimony). PLS reaches competitive performance with far fewer components." width="800" height="328"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The PCR curve dips at 6 components and rises again: adding noisy components &lt;em&gt;hurts&lt;/em&gt; predictions. The PLS curve is more interesting: the strict CV minimum is at 11 components, but 2 components achieve nearly the same CV MSE (143,564 vs 142,554, a gap of about 0.7%). We select 2 because the simpler model generalises better on the test set (MSE 104,839 with 2 components vs 106,891 with 11). This is a common pattern: when the CV curve is flat near the minimum, prefer the simpler model.&lt;/p&gt;
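
&lt;p&gt;The "flat near the minimum" rule can be made mechanical. The sketch below is not the notebook's code: the helper &lt;code&gt;smallest_k_within&lt;/code&gt; and the 1% tolerance are assumptions made for illustration, and only the CV MSE values for 2 and 11 components come from the text above. The other entries are placeholders.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

def smallest_k_within(cv_mse, tol=0.01):
    """Return the smallest component count whose CV MSE lies within a
    relative tolerance `tol` of the minimum (counts start at 1)."""
    cv_mse = np.asarray(cv_mse, dtype=float)
    threshold = cv_mse.min() * (1.0 + tol)
    # Indices of every k whose error is close enough to the best one
    close_enough = np.flatnonzero(np.less_equal(cv_mse, threshold))
    return int(close_enough[0]) + 1

# Hypothetical CV curve for k = 1..11. Only the entries for k=2 (143,564)
# and k=11 (142,554) are taken from the text; the rest are placeholders.
cv_mse = [151000, 143564, 144800, 145200, 144900, 144500,
          144300, 144100, 143900, 143200, 142554]

k_min = int(np.argmin(cv_mse)) + 1            # strict CV minimum
k_simple = smallest_k_within(cv_mse, tol=0.01)
print(k_min, k_simple)  # prints: 11 2
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;A common alternative to a fixed tolerance is the one-standard-error rule, which accepts any model whose CV error is within one standard error of the minimum.&lt;/p&gt;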

&lt;h3&gt;
  
  
  A Peek Inside the Components
&lt;/h3&gt;

&lt;p&gt;What do these principal components actually capture? The PCA loadings reveal which original features contribute to each component:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;loadings&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;DataFrame&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;pca&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;components_&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;columns&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;PC&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)],&lt;/span&gt;
    &lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;columns&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loadings&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sort_values&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;PC1&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ascending&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;round&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fglpi7ibwgf4q7n8f80sh.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fglpi7ibwgf4q7n8f80sh.webp" alt="PCA loadings heatmap showing how each of the 19 features contributes to the first 5 principal components. Career statistics (CRuns, CRBI, CHits) dominate PC1, current-season stats (AtBat, Hits, Runs) dominate PC2, and league indicators dominate PC3." width="800" height="666"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The heatmap reveals the correlation structure clearly. PC1 (39.9% of variance) is dominated by career statistics: CRuns, CRBI, CHits, CAtBat, and CHmRun all have loadings above 0.30. PC2 (21.5%) separates current-season stats (AtBat, Hits, Runs with positive loadings) from career longevity (Years with a negative loading). PC3 picks up the league indicator variables. PCA compresses these correlated groups into single components, which is exactly why dimension reduction works here.&lt;/p&gt;

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Mathematics
&lt;/h3&gt;

&lt;p&gt;&lt;strong&gt;PCR&lt;/strong&gt; decomposes &lt;code&gt;$X$&lt;/code&gt; using PCA:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DX%2520%253D%2520U%2520%255CSigma%2520V%255ET" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DX%2520%253D%2520U%2520%255CSigma%2520V%255ET" alt="equation" width="124" height="22"&gt;&lt;/a&gt;&lt;/p&gt;

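&lt;p&gt;where the columns of &lt;code&gt;$V$&lt;/code&gt; are the principal component directions (eigenvectors of &lt;code&gt;$X^TX$&lt;/code&gt;, with &lt;code&gt;$X$&lt;/code&gt; centred). We keep only the first &lt;code&gt;$k$&lt;/code&gt; columns, &lt;code&gt;$V_k$&lt;/code&gt;, form the scores &lt;code&gt;$Z_k = XV_k$&lt;/code&gt;, and regress &lt;code&gt;$y$&lt;/code&gt; on them:&lt;/p&gt;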

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Chat%257By%257D%2520%253D%2520Z_k%2520%255Chat%257B%255Cbeta%257D_k%2520%253D%2520X%2520V_k%2520%28V_k%255ET%2520X%255ET%2520X%2520V_k%29%255E%257B-1%257D%2520V_k%255ET%2520X%255ET%2520y" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Chat%257By%257D%2520%253D%2520Z_k%2520%255Chat%257B%255Cbeta%257D_k%2520%253D%2520X%2520V_k%2520%28V_k%255ET%2520X%255ET%2520X%2520V_k%29%255E%257B-1%257D%2520V_k%255ET%2520X%255ET%2520y" alt="equation" width="418" height="29"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;This is simply OLS on the &lt;code&gt;$k$&lt;/code&gt; component scores. The key insight: because the scores are mutually orthogonal, the coefficient on each retained component does not change when you add or remove other components. Each component's contribution can be estimated independently.&lt;/p&gt;
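
&lt;p&gt;The orthogonality claim is easy to check numerically. The following is a minimal sketch on synthetic data (not the Hitters data); the helper &lt;code&gt;pcr_coefs&lt;/code&gt; is a hypothetical name introduced only for this illustration.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

rng = np.random.default_rng(0)
n, p = 200, 5
# Synthetic, correlated features (placeholder data, not Hitters)
X = rng.normal(size=(n, p)) @ rng.normal(size=(p, p))
y = X @ rng.normal(size=p) + rng.normal(size=n)

# Centre X and y so that no intercept is needed
Xc = X - X.mean(axis=0)
yc = y - y.mean()

# SVD of the centred data: the rows of Vt are the principal directions
U, s, Vt = np.linalg.svd(Xc, full_matrices=False)

def pcr_coefs(k):
    """OLS coefficients of yc on the first k principal component scores."""
    Z_k = Xc @ Vt[:k].T
    beta_k, *_ = np.linalg.lstsq(Z_k, yc, rcond=None)
    return beta_k

beta_2 = pcr_coefs(2)
beta_4 = pcr_coefs(4)

# The first two coefficients agree whether we keep 2 or 4 components
print(np.allclose(beta_2, beta_4[:2]))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Because the score columns are orthogonal, each coefficient reduces to &lt;code&gt;$u_j^T y / \sigma_j$&lt;/code&gt;, which involves only its own component.&lt;/p&gt;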

&lt;p&gt;&lt;strong&gt;PLS&lt;/strong&gt; looks for the unit-length direction &lt;code&gt;$w$&lt;/code&gt; that maximises the covariance between the projection &lt;code&gt;$Xw$&lt;/code&gt; and &lt;code&gt;$y$&lt;/code&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dw_1%2520%253D%2520%255Carg%255Cmax_%257B%255C%257Cw%255C%257C%253D1%257D%2520%255Ctext%257BCov%257D%28Xw%252C%255C%252C%2520y%29%2520%253D%2520%255Carg%255Cmax_%257B%255C%257Cw%255C%257C%253D1%257D%2520%28Xw%29%255ET%2520y" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257Dw_1%2520%253D%2520%255Carg%255Cmax_%257B%255C%257Cw%255C%257C%253D1%257D%2520%255Ctext%257BCov%257D%28Xw%252C%255C%252C%2520y%29%2520%253D%2520%255Carg%255Cmax_%257B%255C%257Cw%255C%257C%253D1%257D%2520%28Xw%29%255ET%2520y" alt="equation" width="495" height="43"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The first PLS direction &lt;code&gt;$w_1$&lt;/code&gt; is simply &lt;code&gt;$X^T y$&lt;/code&gt; (normalised): its entries are proportional to the covariance between each centred feature and the target. Subsequent directions are found by deflating &lt;code&gt;$X$&lt;/code&gt; (removing the part already explained by the earlier components) and repeating.&lt;/p&gt;
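
&lt;p&gt;As a sanity check on the closed form for &lt;code&gt;$w_1$&lt;/code&gt;, the sketch below compares it with random unit-length directions. The data are synthetic and the helper &lt;code&gt;cov_with_y&lt;/code&gt; is a hypothetical name used only here.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

rng = np.random.default_rng(1)
n, p = 300, 6
X = rng.normal(size=(n, p))                      # placeholder features
y = X[:, 0] - 2.0 * X[:, 3] + rng.normal(size=n)

Xc = X - X.mean(axis=0)
yc = y - y.mean()

# First PLS weight vector: X^T y rescaled to unit length
w1 = Xc.T @ yc
w1 = w1 / np.linalg.norm(w1)

def cov_with_y(w):
    """Sample covariance between the projection Xc @ w and yc."""
    return float((Xc @ w) @ yc) / (n - 1)

# Compare against 1,000 random unit-length directions
W = rng.normal(size=(1000, p))
W = W / np.linalg.norm(W, axis=1, keepdims=True)
best_random = max(cov_with_y(w) for w in W)

print(cov_with_y(w1), best_random)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;No random direction can beat &lt;code&gt;$w_1$&lt;/code&gt;: by the Cauchy-Schwarz inequality, &lt;code&gt;$w^T X^T y$&lt;/code&gt; is largest when &lt;code&gt;$w$&lt;/code&gt; points along &lt;code&gt;$X^T y$&lt;/code&gt;.&lt;/p&gt;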

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8yg25ykrw159t7kp76hm.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8yg25ykrw159t7kp76hm.webp" alt="PCR finds directions of maximum variance in X (unsupervised, then regresses on y), while PLS finds directions that maximise covariance between X and y (supervised from the start)." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  When PCR Wins, When PLS Wins
&lt;/h3&gt;

&lt;p&gt;&lt;strong&gt;PCR is better when:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;The high-variance directions in &lt;code&gt;$X$&lt;/code&gt; genuinely predict &lt;code&gt;$y$&lt;/code&gt; (common in spectroscopy, genomics)&lt;/li&gt;
&lt;li&gt;You have many features and few observations (PCA provides stable variance estimates)&lt;/li&gt;
&lt;li&gt;You want an unsupervised feature extraction that you can reuse across multiple targets&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;PLS is better when:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;The predictive signal sits in low-variance directions of &lt;code&gt;$X$&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;You have a single target and want the most efficient compression&lt;/li&gt;
&lt;li&gt;Your features include many irrelevant high-variance variables&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;In our Hitters example, PLS wins convincingly: 2 components vs 6, and lower test error. The salary signal does not align perfectly with the directions of maximum variance in the batting statistics.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Bias-Variance Tradeoff
&lt;/h3&gt;

&lt;p&gt;Both methods accept a little bias in exchange for lower variance. The test-set results show the effect:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Method&lt;/th&gt;
&lt;th&gt;Components&lt;/th&gt;
&lt;th&gt;Test MSE&lt;/th&gt;
&lt;th&gt;RMSE ($k)&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Full OLS&lt;/td&gt;
&lt;td&gt;19&lt;/td&gt;
&lt;td&gt;117,301&lt;/td&gt;
&lt;td&gt;342&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;PCR&lt;/td&gt;
&lt;td&gt;6&lt;/td&gt;
&lt;td&gt;112,167&lt;/td&gt;
&lt;td&gt;335&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;PLS&lt;/td&gt;
&lt;td&gt;2&lt;/td&gt;
&lt;td&gt;104,839&lt;/td&gt;
&lt;td&gt;324&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Ridge&lt;/td&gt;
&lt;td&gt;--&lt;/td&gt;
&lt;td&gt;99,741&lt;/td&gt;
&lt;td&gt;316&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;Full OLS uses all 19 features but has high variance (unstable coefficients). PCR and PLS introduce some bias by discarding information, but the reduction in variance more than compensates. Ridge regression (included for comparison) achieves the lowest error by shrinking coefficients rather than discarding components.&lt;/p&gt;
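
&lt;p&gt;The RMSE column follows directly from the Test MSE column. A quick arithmetic check, assuming (as the column header suggests) that salaries are recorded in thousands of dollars:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import math

# Test MSE values copied from the table above
test_mse = {"Full OLS": 117_301, "PCR": 112_167, "PLS": 104_839, "Ridge": 99_741}

# Square root of the MSE gives the typical error in the target's units ($k)
rmse = {name: round(math.sqrt(mse)) for name, mse in test_mse.items()}
print(rmse)
# {'Full OLS': 342, 'PCR': 335, 'PLS': 324, 'Ridge': 316}

# Relative reduction in test MSE compared with full OLS, in percent
baseline = test_mse["Full OLS"]
reduction = {name: round(100 * (1 - mse / baseline), 1)
             for name, mse in test_mse.items() if name != "Full OLS"}
print(reduction)
# {'PCR': 4.4, 'PLS': 10.6, 'Ridge': 15.0}
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
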

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzwgoyjjgn3lq6d7iesax.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzwgoyjjgn3lq6d7iesax.webp" alt="Bar chart comparing test MSE across four methods: Full OLS, PCR with 6 components, PLS with 2 components, and Ridge regression." width="800" height="515"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The practical message: when features are correlated, you rarely need all of them. The question is whether to reduce dimensions unsupervised (PCR), supervised (PLS), or regularise without reducing dimensions at all (Ridge).&lt;/p&gt;
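
&lt;p&gt;One way to see the difference between discarding components and shrinking coefficients is to look at how much of each principal direction survives. PCR keeps a direction entirely or drops it, while Ridge scales direction &lt;code&gt;$j$&lt;/code&gt; by &lt;code&gt;$d_j^2 / (d_j^2 + \lambda)$&lt;/code&gt;, where &lt;code&gt;$d_j$&lt;/code&gt; is its singular value. The sketch below uses synthetic data, and the values of &lt;code&gt;k&lt;/code&gt; and &lt;code&gt;lam&lt;/code&gt; are arbitrary assumptions.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

rng = np.random.default_rng(2)
# Synthetic correlated features (placeholder data, not Hitters)
X = rng.normal(size=(100, 5)) @ rng.normal(size=(5, 5))
Xc = X - X.mean(axis=0)

# Singular values d_j of the centred data, largest first
d = np.linalg.svd(Xc, compute_uv=False)

k = 2                          # hypothetical number of PCR components
lam = float(np.median(d**2))   # hypothetical Ridge penalty

# PCR: keep the first k directions in full, drop the rest
pcr_factors = np.concatenate([np.ones(k), np.zeros(d.size - k)])

# Ridge: shrink every direction, low-variance directions the most
ridge_factors = d**2 / (d**2 + lam)

print(np.round(pcr_factors, 2))
print(np.round(ridge_factors, 2))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;The PCR factors jump from 1 to 0 at the cutoff, whereas the Ridge factors decay smoothly and never reach either extreme.&lt;/p&gt;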

&lt;h3&gt;
  
  
  When NOT to Use PCR or PLS
&lt;/h3&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Few features, many observations.&lt;/strong&gt; If &lt;code&gt;$p \ll n$&lt;/code&gt;, multicollinearity is less of a problem and OLS works fine.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Interpretability is critical.&lt;/strong&gt; The principal components are linear combinations of all features, so individual feature effects are obscured. If you need to say "an extra home run is worth $X in salary," use Ridge or Lasso instead.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Non-linear relationships.&lt;/strong&gt; PCR and PLS are linear methods. For non-linear patterns, consider &lt;a href="https://sesen.ai/blog/gaussian-process-regression-from-scratch" rel="noopener noreferrer"&gt;Gaussian process regression&lt;/a&gt; or tree-based models.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Sparse signals.&lt;/strong&gt; If only a few features matter and the rest are noise, Lasso (L1 regularisation) does feature &lt;em&gt;selection&lt;/em&gt; rather than feature &lt;em&gt;combination&lt;/em&gt;, which is usually more effective.&lt;/li&gt;
&lt;/ol&gt;

&lt;h2&gt;
  
  
  Deep Dive: The Papers
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Principal Component Regression
&lt;/h3&gt;

&lt;p&gt;The idea of using principal components as regression predictors dates to &lt;strong&gt;Massy (1965)&lt;/strong&gt;, &lt;a href="https://doi.org/10.1080/01621459.1965.10480810" rel="noopener noreferrer"&gt;"Principal Components Regression in Exploratory Statistical Research"&lt;/a&gt;, published in the &lt;em&gt;Journal of the American Statistical Association&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;Massy was working on marketing research problems where survey data had dozens of correlated variables. He proposed a two-step procedure: extract principal components, then regress on the top &lt;code&gt;$k$&lt;/code&gt;. His key insight:&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"By using the principal components as the independent variables in the regression, we avoid the multicollinearity problem since the components are orthogonal."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;The underlying PCA dates back further to &lt;strong&gt;Hotelling (1933)&lt;/strong&gt;, &lt;a href="https://doi.org/10.1037/h0071325" rel="noopener noreferrer"&gt;"Analysis of a complex of statistical variables into principal components"&lt;/a&gt;, &lt;em&gt;Journal of Educational Psychology&lt;/em&gt;. Hotelling formalised the eigenvalue decomposition of the covariance matrix, though the core idea appeared even earlier in Pearson (1901).&lt;/p&gt;

&lt;h3&gt;
  
  
  Partial Least Squares
&lt;/h3&gt;

&lt;p&gt;PLS was developed by &lt;strong&gt;Herman Wold&lt;/strong&gt; in the 1960s and 1970s, originally for econometrics. The foundational paper is &lt;strong&gt;Wold (1975)&lt;/strong&gt;, "Soft modelling by latent variables: the non-linear iterative partial least squares (NIPALS) approach," in &lt;em&gt;Perspectives in Probability and Statistics&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;Herman's son, &lt;strong&gt;Svante Wold&lt;/strong&gt;, later popularised PLS in chemometrics with a landmark review: &lt;strong&gt;Wold, Sjöström &amp;amp; Eriksson (2001)&lt;/strong&gt;, &lt;a href="https://doi.org/10.1016/S0169-7439(01)00155-1" rel="noopener noreferrer"&gt;"PLS-regression: a basic tool of chemometrics"&lt;/a&gt;, &lt;em&gt;Chemometrics and Intelligent Laboratory Systems&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;A widely used non-iterative alternative to NIPALS is &lt;strong&gt;SIMPLS&lt;/strong&gt; by &lt;strong&gt;de Jong (1993)&lt;/strong&gt;, &lt;a href="https://doi.org/10.1016/0169-7439(93)85002-X" rel="noopener noreferrer"&gt;"SIMPLS: An alternative approach to partial least squares regression"&lt;/a&gt;. de Jong's algorithm obtains the PLS components without deflating &lt;code&gt;$X$&lt;/code&gt; at every step, which makes it faster. Note that scikit-learn's &lt;code&gt;PLSRegression&lt;/code&gt; follows the NIPALS approach rather than SIMPLS.&lt;/p&gt;

&lt;h3&gt;
  
  
  The ISLR Connection
&lt;/h3&gt;

&lt;p&gt;This tutorial is based on the lab exercise in &lt;strong&gt;James, Witten, Hastie &amp;amp; Tibshirani (2021)&lt;/strong&gt;, &lt;a href="https://www.statlearning.com/" rel="noopener noreferrer"&gt;&lt;em&gt;An Introduction to Statistical Learning&lt;/em&gt;&lt;/a&gt;, Chapter 6. ISLR provides an excellent treatment of PCR and PLS in the context of the bias-variance tradeoff, alongside Ridge and Lasso regression.&lt;/p&gt;

&lt;p&gt;The Hitters dataset used here has become a standard benchmark for comparing regularisation and dimension reduction methods. With 19 correlated features, it sits in the sweet spot where these methods make a visible difference.&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;ISLR Chapter 6&lt;/strong&gt; (free online) - PCR, PLS, Ridge, and Lasso side by side&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Hastie, Tibshirani &amp;amp; Friedman (2009)&lt;/strong&gt;, &lt;em&gt;The Elements of Statistical Learning&lt;/em&gt;, Section 3.5 - Rigorous treatment&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Abdi (2010)&lt;/strong&gt;, &lt;a href="https://doi.org/10.1002/wics.51" rel="noopener noreferrer"&gt;"Partial least squares regression and projection on latent structure regression"&lt;/a&gt; - Excellent modern overview&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/supervised/pcr_vs_pls.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Scree plot.&lt;/strong&gt; Plot the explained variance per component and the cumulative curve. How many components do you need to capture 95% of the variance?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;PLS loadings.&lt;/strong&gt; Compare the PLS weight vectors to the PCA loadings. Which features does PLS prioritise that PCA does not?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Ridge vs PCR.&lt;/strong&gt; Add a Ridge regression (with &lt;code&gt;RidgeCV&lt;/code&gt;) to the comparison. In what sense is Ridge a "soft" version of PCR?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Log-transform the target.&lt;/strong&gt; Salary is right-skewed. Does predicting &lt;code&gt;$\log(\text{Salary})$&lt;/code&gt; change which method wins?&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Understanding PCR and PLS builds directly on &lt;a href="https://sesen.ai/blog/linear-regression-five-ways" rel="noopener noreferrer"&gt;linear regression&lt;/a&gt; and connects to &lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;Bayesian inference&lt;/a&gt; through the regularisation-as-prior interpretation. When the linear assumption breaks down, &lt;a href="https://sesen.ai/blog/gaussian-process-regression-from-scratch" rel="noopener noreferrer"&gt;Gaussian process regression&lt;/a&gt; offers a non-parametric alternative that handles high-dimensional inputs gracefully.&lt;/p&gt;

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/regression-playground" rel="noopener noreferrer"&gt;Regression Playground&lt;/a&gt; — Fit and compare regression models interactively in the browser&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/linear-regression-five-ways" rel="noopener noreferrer"&gt;Linear Regression Five Ways&lt;/a&gt; — The foundation both PCR and PLS build on&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/lda-vs-pca-supervised-unsupervised-dimensionality-reduction" rel="noopener noreferrer"&gt;LDA vs PCA: Supervised vs Unsupervised Dimensionality Reduction&lt;/a&gt; — The classification counterpart to PCR vs PLS&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt; — How regularisation connects to Bayesian priors&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/gaussian-process-regression-from-scratch" rel="noopener noreferrer"&gt;Gaussian Process Regression from Scratch&lt;/a&gt; — A non-parametric alternative when linearity breaks down&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  What is the key difference between PCR and PLS?
&lt;/h3&gt;

&lt;p&gt;PCR finds directions of maximum variance in the features without considering the target variable, then regresses on those directions. PLS instead chooses directions that maximise the covariance between the projected features and the target. Because PLS is supervised from the start, it typically needs fewer components to achieve the same predictive performance.&lt;/p&gt;

&lt;h3&gt;
  
  
  When should I use PCR instead of PLS?
&lt;/h3&gt;

&lt;p&gt;PCR is preferable when the high-variance directions in your features genuinely predict the target, which is common in spectroscopy and genomics. It is also useful when you want an unsupervised feature extraction that can be reused across multiple target variables. PLS is better when the predictive signal sits in low-variance directions or when many high-variance features are irrelevant to the outcome.&lt;/p&gt;

&lt;h3&gt;
  
  
  How do I choose the number of components for PCR or PLS?
&lt;/h3&gt;

&lt;p&gt;Use k-fold cross-validation to evaluate predictive performance at each number of components and select the value that minimises the cross-validation error. When the error curve is flat near the minimum, prefer the simpler model with fewer components, as it tends to generalise better on unseen data.&lt;/p&gt;
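
&lt;p&gt;A minimal sketch of that procedure for PCR with scikit-learn is shown below. The data come from &lt;code&gt;make_regression&lt;/code&gt; as a synthetic stand-in (19 features is an assumption chosen to mirror Hitters), so the selected value will not match the results in this post.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np
from sklearn.datasets import make_regression
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for a real dataset (assumption: 19 features)
X, y = make_regression(n_samples=200, n_features=19, n_informative=5,
                       noise=10.0, random_state=0)

cv = KFold(n_splits=10, shuffle=True, random_state=0)
ks = list(range(1, X.shape[1] + 1))
cv_mse = []
for k in ks:
    # Scaling and PCA sit inside the pipeline, so each fold fits them
    # on its own training split only
    model = make_pipeline(StandardScaler(), PCA(n_components=k), LinearRegression())
    scores = cross_val_score(model, X, y, cv=cv, scoring="neg_mean_squared_error")
    cv_mse.append(-scores.mean())

best_k = ks[int(np.argmin(cv_mse))]
print(best_k)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;For PLS, the same loop applies with &lt;code&gt;PLSRegression(n_components=k)&lt;/code&gt; from &lt;code&gt;sklearn.cross_decomposition&lt;/code&gt; in place of the PCA and linear regression steps.&lt;/p&gt;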

&lt;h3&gt;
  
  
  Why did PLS with 2 components beat PCR with 6 components on the Hitters dataset?
&lt;/h3&gt;

&lt;p&gt;The salary signal in the baseball data does not align well with the directions of maximum variance in the batting statistics. Career statistics dominate the first few principal components, but salary depends on a subtler combination of recent performance and league factors. PLS finds these salary-relevant directions directly, so it needs far fewer components.&lt;/p&gt;

&lt;h3&gt;
  
  
  How does PCR compare to Ridge regression?
&lt;/h3&gt;

&lt;p&gt;Both methods address multicollinearity, but in different ways. PCR discards the least important principal components entirely, introducing a hard cutoff. Ridge regression shrinks all coefficients towards zero without discarding any, acting as a soft version of dimension reduction. Ridge often achieves lower test error because it retains some information from every direction.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can I interpret individual feature effects with PCR or PLS?
&lt;/h3&gt;

&lt;p&gt;Not directly. The components are linear combinations of all original features, so individual feature effects are obscured. If you need to say that a specific feature is worth a certain amount, use Ridge or Lasso regression instead, which produce interpretable coefficients for each original variable.&lt;/p&gt;

</description>
      <category>supervisedlearning</category>
      <category>statistics</category>
    </item>
    <item>
      <title>Text Classification from Scratch: TF-IDF and Naive Bayes</title>
      <dc:creator>Berkan Sesen</dc:creator>
      <pubDate>Fri, 17 Apr 2026 12:46:47 +0000</pubDate>
      <link>https://dev.to/berkan_sesen/text-classification-from-scratch-tf-idf-and-naive-bayes-3pff</link>
      <guid>https://dev.to/berkan_sesen/text-classification-from-scratch-tf-idf-and-naive-bayes-3pff</guid>
      <description>&lt;p&gt;Every morning, your inbox separates spam from real email. News apps sort articles into sports, tech, and politics. Customer support systems route tickets to the right team. Behind all of these is text classification: teaching a machine to read a document and assign it a category.&lt;/p&gt;

&lt;p&gt;The building blocks are simpler than you might expect. You need a way to convert text into numbers (TF-IDF), a classifier that works well with sparse, high-dimensional data (Naive Bayes), and a few lines of code to tie them together. No deep learning, no GPUs, no embeddings.&lt;/p&gt;

&lt;p&gt;By the end of this post, you'll classify news articles into 20 categories with 77% accuracy using just 10 lines of Python, then push that to 84% with hyperparameter tuning. You'll understand exactly how TF-IDF works and why the "naive" independence assumption in Naive Bayes is a feature, not a bug.&lt;/p&gt;

&lt;h2&gt;
  
  
  Let's Build It
&lt;/h2&gt;

&lt;p&gt;Click the badge to open the interactive notebook:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/nlp/tfidf_naive_bayes.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Here's the complete classifier. We use scikit-learn's 20 Newsgroups dataset, which contains around 18,000 posts across 20 topics, from computer graphics to religion to space exploration:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.datasets&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;fetch_20newsgroups&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.feature_extraction.text&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;TfidfTransformer&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.naive_bayes&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;MultinomialNB&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.pipeline&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Pipeline&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.metrics&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;accuracy_score&lt;/span&gt;

&lt;span class="c1"&gt;# Load training and test data
&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;fetch_20newsgroups&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;subset&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;train&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;twenty_test&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;fetch_20newsgroups&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;subset&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;test&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Build the pipeline: raw text → word counts → TF-IDF → Naive Bayes
&lt;/span&gt;&lt;span class="n"&gt;text_clf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Pipeline&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;vect&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tfidf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;TfidfTransformer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;clf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;MultinomialNB&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
&lt;span class="p"&gt;])&lt;/span&gt;

&lt;span class="c1"&gt;# Train and evaluate
&lt;/span&gt;&lt;span class="n"&gt;text_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;predicted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;text_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Accuracy: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;accuracy_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;predicted&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Accuracy: 77.4%
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;With 10 lines of modelling code, we classify documents into one of 20 categories at 77.4% accuracy on unseen data. Random guessing would give 5%.&lt;/p&gt;

&lt;p&gt;Let's test it on fresh sentences the model has never seen:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;docs_new&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;OpenGL shading techniques for real-time rendering&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;The Detroit Tigers signed a new pitcher today&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;NASA launched the James Webb telescope last year&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Is there evidence for the existence of God?&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="n"&gt;predicted_new&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;text_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;docs_new&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;doc&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;category&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;docs_new&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;predicted_new&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target_names&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;category&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt;&lt;span class="mi"&gt;28&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;  ←  &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;doc&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;            comp.graphics  ←  OpenGL shading techniques for real-time rendering
        rec.sport.baseball  ←  The Detroit Tigers signed a new pitcher today
                 sci.space  ←  NASA launched the James Webb telescope last year
    soc.religion.christian  ←  Is there evidence for the existence of God?
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The model correctly identifies the topic of each sentence. It works by finding which words are most characteristic of each category.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F6d7j0fw1g64gfr0odiqn.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F6d7j0fw1g64gfr0odiqn.webp" alt="Confusion matrix heatmap for the Naive Bayes classifier on 20 Newsgroups. The diagonal shows correct predictions; off-diagonal cells reveal common confusions between related topics." width="800" height="707"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The confusion matrix reveals where the classifier struggles. Related categories like &lt;code&gt;comp.sys.ibm.pc.hardware&lt;/code&gt; and &lt;code&gt;comp.sys.mac.hardware&lt;/code&gt; (both about computer hardware) are frequently confused, as are &lt;code&gt;talk.religion.misc&lt;/code&gt; and &lt;code&gt;soc.religion.christian&lt;/code&gt;. These confusions make intuitive sense: documents about Mac hardware and PC hardware use very similar vocabulary.&lt;/p&gt;
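
&lt;p&gt;One way to pull the most-confused pairs out of a set of predictions, sketched with made-up labels. In the setting above you would pass &lt;code&gt;twenty_test.target&lt;/code&gt; and &lt;code&gt;predicted&lt;/code&gt; instead of these hypothetical lists:&lt;/p&gt;

```python
from collections import Counter

# Hypothetical true and predicted labels, made up for illustration
y_true = ['pc', 'pc', 'mac', 'mac', 'mac', 'space', 'space', 'pc']
y_pred = ['pc', 'mac', 'mac', 'pc', 'pc', 'space', 'space', 'pc']

# Count only the off-diagonal cells: (true label, predicted label) pairs that differ
confusions = Counter((t, p) for t, p in zip(y_true, y_pred) if t != p)

top_pair, top_count = confusions.most_common(1)[0]
print(top_pair, top_count)
```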

&lt;h2&gt;
  
  
  What Just Happened?
&lt;/h2&gt;

&lt;p&gt;Three components work in sequence: CountVectorizer turns text into word counts, TfidfTransformer re-weights those counts to highlight distinctive words, and MultinomialNB learns which words signal which categories.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzh05hb9r806g3498dsss.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fzh05hb9r806g3498dsss.webp" alt="The text classification pipeline: raw text flows through tokenisation, word counting, TF-IDF weighting, and finally the Naive Bayes classifier to produce a category prediction." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  Step 1: Turning Text into Numbers
&lt;/h3&gt;

&lt;p&gt;A machine learning model can't read English. It needs numbers. The simplest conversion is the &lt;strong&gt;bag of words&lt;/strong&gt;: count how many times each word appears in a document, ignoring order entirely.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.feature_extraction.text&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;CountVectorizer&lt;/span&gt;

&lt;span class="n"&gt;corpus&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;The cat sat on the mat&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;The dog sat on the log&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;The cat chased the dog&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;vectorizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;vectorizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit_transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;corpus&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vectorizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_feature_names_out&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="c1"&gt;# ['cat', 'chased', 'dog', 'log', 'mat', 'on', 'sat', 'the']
&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;toarray&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="c1"&gt;# [[1, 0, 0, 0, 1, 1, 1, 2],
#  [0, 0, 1, 1, 0, 1, 1, 2],
#  [1, 1, 1, 0, 0, 0, 0, 2]]
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Each row is a document. Each column is a word from the vocabulary. The value is the word count. Notice that "the" always gets a count of 2, regardless of the document. It's everywhere, so it carries no information about which document you're looking at.&lt;/p&gt;
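
&lt;p&gt;To see that there is no magic involved, here is a rough re-implementation of the same counting with only the standard library. It assumes the &lt;code&gt;CountVectorizer&lt;/code&gt; defaults of lowercasing and keeping tokens of two or more word characters; &lt;code&gt;tokenize&lt;/code&gt; is a hypothetical helper written for this sketch:&lt;/p&gt;

```python
import re
from collections import Counter

corpus = [
    'The cat sat on the mat',
    'The dog sat on the log',
    'The cat chased the dog',
]


def tokenize(text):
    # Lowercase, then keep runs of two or more word characters
    return re.findall(r'\b\w\w+\b', text.lower())


# One Counter of token frequencies per document
counts = [Counter(tokenize(doc)) for doc in corpus]

# Alphabetically sorted vocabulary across all documents
vocabulary = sorted(set().union(*counts))

# Document-term matrix: a missing word counts as 0
matrix = [[c[word] for word in vocabulary] for c in counts]

print(vocabulary)
for row in matrix:
    print(row)
```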

&lt;p&gt;On the 20 Newsgroups training set, CountVectorizer discovers around 130,000 unique tokens. Each document becomes a vector of 130,000 dimensions, mostly zeros (since any single post uses only a tiny fraction of the full vocabulary).&lt;/p&gt;

&lt;h3&gt;
  
  
  Step 2: Weighting Words That Matter
&lt;/h3&gt;

&lt;p&gt;Not all words are equally informative. Words like "the", "is", and "a" appear in almost every document. What we want are words that are frequent within a particular document but rare across the corpus as a whole. This is exactly what &lt;strong&gt;TF-IDF&lt;/strong&gt; (Term Frequency, Inverse Document Frequency) captures.&lt;/p&gt;

&lt;p&gt;The weight for word &lt;code&gt;$t$&lt;/code&gt; in document &lt;code&gt;$d$&lt;/code&gt; is:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BTF-IDF%257D%28t%252C%2520d%29%2520%253D%2520%255Ctext%257BTF%257D%28t%252C%2520d%29%2520%255Ctimes%2520%255Ctext%257BIDF%257D%28t%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BTF-IDF%257D%28t%252C%2520d%29%2520%253D%2520%255Ctext%257BTF%257D%28t%252C%2520d%29%2520%255Ctimes%2520%255Ctext%257BIDF%257D%28t%29" alt="equation" width="354" height="25"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Where:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;TF&lt;/strong&gt; (term frequency) = how often the word appears in this document&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;IDF&lt;/strong&gt; (inverse document frequency) = &lt;code&gt;$\log\!\frac{1+N}{1+n_t}+1$&lt;/code&gt;, where &lt;code&gt;$N$&lt;/code&gt; is the total number of documents and &lt;code&gt;$n_t$&lt;/code&gt; is the number of documents containing word &lt;code&gt;$t$&lt;/code&gt;
&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;A word that appears in every document gets a low IDF, shrinking its weight. A word that appears in only a few documents gets a high IDF, amplifying its signal.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.feature_extraction.text&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;TfidfTransformer&lt;/span&gt;

&lt;span class="n"&gt;tfidf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;TfidfTransformer&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;X_tfidf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfidf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit_transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;round&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_tfidf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;toarray&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fb4j1bjzzhbbztenq0gn1.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fb4j1bjzzhbbztenq0gn1.webp" alt="TF-IDF heatmap for the toy corpus. Common words like " width="800" height="298"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;After TF-IDF weighting, the document vectors highlight what's distinctive about each text rather than what's common across all of them.&lt;/p&gt;
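
&lt;p&gt;The IDF formula above can be checked by hand on the toy corpus. This sketch uses only the standard library and assumes the &lt;code&gt;TfidfTransformer&lt;/code&gt; defaults: smoothed IDF followed by L2 normalisation of each row. &lt;code&gt;tfidf_row&lt;/code&gt; is a hypothetical helper, not part of scikit-learn:&lt;/p&gt;

```python
import math

vocabulary = ['cat', 'chased', 'dog', 'log', 'mat', 'on', 'sat', 'the']
counts = [
    [1, 0, 0, 0, 1, 1, 1, 2],
    [0, 0, 1, 1, 0, 1, 1, 2],
    [1, 1, 1, 0, 0, 0, 0, 2],
]

n_docs = len(counts)

# n_t: number of documents that contain each word at least once
doc_freq = [sum(1 for row in counts if row[j] > 0) for j in range(len(vocabulary))]

# Smoothed IDF: log((1 + N) / (1 + n_t)) + 1
idf = [math.log((1 + n_docs) / (1 + df)) + 1 for df in doc_freq]


def tfidf_row(row):
    # Multiply term counts by IDF, then scale the row to unit L2 length
    weighted = [tf * w for tf, w in zip(row, idf)]
    norm = math.sqrt(sum(v * v for v in weighted))
    return [v / norm for v in weighted]


tfidf = [tfidf_row(row) for row in counts]

for word, w in zip(vocabulary, idf):
    print(f'{word:>7s}  idf={w:.3f}')
```

&lt;p&gt;"the" appears in all three documents, so its IDF is exactly 1, the lowest value this formula can produce. Words found in a single document, such as "mat", get the highest IDF.&lt;/p&gt;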

&lt;h3&gt;
  
  
  Step 3: Naive Bayes Classification
&lt;/h3&gt;

&lt;p&gt;Naive Bayes applies &lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;Bayes' theorem&lt;/a&gt; to classify documents. Given a document with words &lt;code&gt;$w_1, w_2, \ldots, w_n$&lt;/code&gt;, it computes:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28%255Ctext%257Bclass%257D%2520%255Cmid%2520w_1%252C%2520w_2%252C%2520%255Cldots%252C%2520w_n%29%2520%255Cpropto%2520P%28%255Ctext%257Bclass%257D%29%2520%255Cprod_%257Bi%253D1%257D%255E%257Bn%257D%2520P%28w_i%2520%255Cmid%2520%255Ctext%257Bclass%257D%29" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28%255Ctext%257Bclass%257D%2520%255Cmid%2520w_1%252C%2520w_2%252C%2520%255Cldots%252C%2520w_n%29%2520%255Cpropto%2520P%28%255Ctext%257Bclass%257D%29%2520%255Cprod_%257Bi%253D1%257D%255E%257Bn%257D%2520P%28w_i%2520%255Cmid%2520%255Ctext%257Bclass%257D%29" alt="equation" width="547" height="68"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The "naive" part is the assumption that words are &lt;strong&gt;conditionally independent&lt;/strong&gt; given the class. This is obviously wrong: the word "neural" is far more likely to appear near "network" than near "baseball". But the simplification works remarkably well in practice because:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;We only need the ranking right&lt;/strong&gt;, not the exact probabilities. If &lt;code&gt;$P(\text{sci.space} \mid \text{doc})$&lt;/code&gt; is the highest, the prediction is correct even if the probability value itself is off.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Independence errors tend to cancel out&lt;/strong&gt; across thousands of features.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;The alternative (modelling all word dependencies) is intractable&lt;/strong&gt; for vocabularies of 130,000 words.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The &lt;code&gt;MultinomialNB&lt;/code&gt; variant uses word counts (or TF-IDF weights) as features and models &lt;code&gt;$P(w_i \mid \text{class})$&lt;/code&gt; as a multinomial distribution. The parameters are estimated via &lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;maximum likelihood&lt;/a&gt;: the probability of word &lt;code&gt;$w_i$&lt;/code&gt; in class &lt;code&gt;$c$&lt;/code&gt; is simply the share of all word occurrences in training documents of class &lt;code&gt;$c$&lt;/code&gt; that are &lt;code&gt;$w_i$&lt;/code&gt;, with Laplace smoothing so that a word never seen in a given class during training does not receive a probability of zero.&lt;/p&gt;
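
&lt;p&gt;The estimation and prediction steps described above fit in a few lines of plain Python. This is a simplified sketch on a made-up four-word vocabulary, not scikit-learn's implementation. It works in log space to avoid underflow when many small probabilities are multiplied:&lt;/p&gt;

```python
import math

# Toy word-count matrix (rows: documents, columns: vocabulary words).
# Vocabulary, counts, and labels are made up for illustration.
vocabulary = ['ball', 'pitcher', 'orbit', 'rocket']
X = [
    [2, 1, 0, 0],
    [1, 2, 0, 0],
    [0, 0, 2, 1],
    [0, 1, 1, 2],
]
y = ['sport', 'sport', 'space', 'space']
alpha = 1.0  # Laplace smoothing

classes = sorted(set(y))
log_prior = {}
log_likelihood = {}
for c in classes:
    rows = [row for row, label in zip(X, y) if label == c]
    # Prior: fraction of training documents in class c
    log_prior[c] = math.log(len(rows) / len(X))
    # Smoothed share of each word among all word occurrences in class c
    word_totals = [sum(col) for col in zip(*rows)]
    denom = sum(word_totals) + alpha * len(vocabulary)
    log_likelihood[c] = [math.log((n + alpha) / denom) for n in word_totals]


def predict(doc_counts):
    # Score = log prior + sum over words of count * log P(word | class)
    scores = {
        c: log_prior[c] + sum(n * lp for n, lp in zip(doc_counts, log_likelihood[c]))
        for c in classes
    }
    return max(scores, key=scores.get)


print(predict([1, 1, 0, 0]))
print(predict([0, 0, 1, 1]))
```

&lt;p&gt;Without smoothing, "orbit" would have probability zero in the sport class, and any document containing it could never be assigned to sport regardless of its other words.&lt;/p&gt;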

&lt;h3&gt;
  
  
  The Pipeline: Composing the Steps
&lt;/h3&gt;

&lt;p&gt;Scikit-learn's &lt;code&gt;Pipeline&lt;/code&gt; chains these three transformations so you can treat the entire workflow as a single estimator:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;text_clf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Pipeline&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;vect&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;     &lt;span class="c1"&gt;# raw text → word counts
&lt;/span&gt;    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tfidf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;TfidfTransformer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;   &lt;span class="c1"&gt;# word counts → TF-IDF weights
&lt;/span&gt;    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;clf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;MultinomialNB&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;        &lt;span class="c1"&gt;# TF-IDF vectors → class predictions
&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;When you call &lt;code&gt;text_clf.fit(X, y)&lt;/code&gt;, it runs &lt;code&gt;CountVectorizer.fit_transform()&lt;/code&gt;, feeds the output to &lt;code&gt;TfidfTransformer.fit_transform()&lt;/code&gt;, then passes the result to &lt;code&gt;MultinomialNB.fit()&lt;/code&gt;. At prediction time, the same chain runs in sequence, but the first two steps call &lt;code&gt;transform()&lt;/code&gt; rather than &lt;code&gt;fit_transform()&lt;/code&gt;, so the vocabulary and IDF weights learned from the training data are reused before &lt;code&gt;MultinomialNB.predict()&lt;/code&gt; is called. This also means you can do grid search over any parameter in the pipeline using the double-underscore naming convention (&lt;code&gt;vect__ngram_range&lt;/code&gt;, &lt;code&gt;clf__alpha&lt;/code&gt;).&lt;/p&gt;
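
&lt;p&gt;To make the chaining explicit, this sketch fits the same three steps by hand and checks that the predictions match the pipeline. It uses a tiny made-up corpus with hypothetical labels so it runs without downloading 20 Newsgroups:&lt;/p&gt;

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Tiny made-up corpus; labels are hypothetical (0 = sport, 1 = space)
train_docs = [
    'the pitcher threw the ball',
    'the team won the baseball game',
    'the rocket reached orbit',
    'nasa launched a new satellite',
]
train_labels = [0, 0, 1, 1]
test_docs = ['the satellite is in orbit', 'a great baseball pitcher']

# Pipeline version: one fit call, one predict call
pipe = Pipeline([
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('clf', MultinomialNB()),
])
pipe.fit(train_docs, train_labels)

# Manual version: fit_transform on training data, transform only at prediction time
vect = CountVectorizer()
tfidf = TfidfTransformer()
clf = MultinomialNB()
clf.fit(tfidf.fit_transform(vect.fit_transform(train_docs)), train_labels)
manual_pred = clf.predict(tfidf.transform(vect.transform(test_docs)))

print(np.array_equal(pipe.predict(test_docs), manual_pred))
```

&lt;p&gt;The manual version is also where data leakage tends to creep in: calling &lt;code&gt;fit_transform()&lt;/code&gt; on the test documents would rebuild the vocabulary from test data. The pipeline avoids that mistake by construction.&lt;/p&gt;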

&lt;h2&gt;
  
  
  Going Deeper
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Beating the Baseline
&lt;/h3&gt;

&lt;p&gt;Naive Bayes at 77.4% is a strong starting point, but we can improve it in three ways: removing noise (stop words), capturing phrases (bigrams), and tuning the smoothing parameter.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Stop words&lt;/strong&gt; are common words ("the", "is", "at") that carry little discriminative value. Removing them reduces noise and bumps accuracy from 77.4% to 81.7%:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;text_clf_stop&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Pipeline&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;vect&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stop_words&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;english&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tfidf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;TfidfTransformer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;clf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;MultinomialNB&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;text_clf_stop&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;NB + stop words: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;accuracy_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text_clf_stop&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# NB + stop words: 81.7%
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That is a gain of 4.3 percentage points from a single parameter change.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Grid search&lt;/strong&gt; systematically explores combinations of pipeline parameters. The naming convention (&lt;code&gt;vect__&lt;/code&gt;, &lt;code&gt;tfidf__&lt;/code&gt;, &lt;code&gt;clf__&lt;/code&gt;) lets you reach into any pipeline step:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.model_selection&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;GridSearchCV&lt;/span&gt;

&lt;span class="n"&gt;parameters&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;vect__ngram_range&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;[(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)],&lt;/span&gt;  &lt;span class="c1"&gt;# unigrams vs unigrams+bigrams
&lt;/span&gt;    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tfidf__use_idf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;         &lt;span class="c1"&gt;# use IDF weighting or not
&lt;/span&gt;    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;clf__alpha&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1e-2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;              &lt;span class="c1"&gt;# smoothing strength
&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;

&lt;span class="n"&gt;gs_clf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;GridSearchCV&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text_clf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cv&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_jobs&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;gs_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best CV score: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;gs_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_score_&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Best params: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;gs_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;best_params_&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Test accuracy: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;accuracy_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;gs_clf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Best CV score: 91.6%
# Best params: {'clf__alpha': 0.001, 'tfidf__use_idf': True, 'vect__ngram_range': (1, 2)}
# Test accuracy: 83.6%
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The best configuration uses unigrams plus bigrams (&lt;code&gt;ngram_range=(1,2)&lt;/code&gt;), IDF weighting, and weak smoothing (&lt;code&gt;alpha=0.001&lt;/code&gt;). Bigrams capture phrases like "White House" or "hard drive" that individual words miss. The 5-fold CV score (91.6%) is higher than the test accuracy (83.6%) because each validation fold is drawn from the same pool of posts as the training folds, while the held-out test set may contain authors, topics, or writing styles not seen during training. In the standard by-date 20 Newsgroups split, the test posts were also written later than the training posts.&lt;/p&gt;

&lt;p&gt;If you've read our &lt;a href="https://sesen.ai/blog/hyperparameter-optimization-grid-random-bayesian" rel="noopener noreferrer"&gt;hyperparameter optimisation post&lt;/a&gt;, you'll recognise grid search as the brute-force baseline. With only 8 combinations to evaluate here, it's fast enough.&lt;/p&gt;
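&lt;p&gt;To see where the 8 comes from, the grid can be enumerated by hand. This is a minimal standard-library sketch that mirrors the &lt;code&gt;parameters&lt;/code&gt; dictionary above; the variable names &lt;code&gt;combinations&lt;/code&gt;, &lt;code&gt;n_folds&lt;/code&gt;, and &lt;code&gt;n_fits&lt;/code&gt; are illustrative assumptions, not part of scikit-learn:&lt;/p&gt;

```python
from itertools import product

# Same grid as the GridSearchCV example: two candidate values per hyperparameter.
parameters = {
    'vect__ngram_range': [(1, 1), (1, 2)],
    'tfidf__use_idf': (True, False),
    'clf__alpha': (1e-2, 1e-3),
}

# The Cartesian product of all candidate values gives every combination to try.
names = list(parameters)
combinations = [dict(zip(names, values)) for values in product(*parameters.values())]

n_folds = 5  # matches cv=5 in the grid search above
n_fits = len(combinations) * n_folds  # one fit per combination per fold

print(f'{len(combinations)} combinations, {n_fits} cross-validation fits')
```

&lt;p&gt;With the default &lt;code&gt;refit=True&lt;/code&gt;, &lt;code&gt;GridSearchCV&lt;/code&gt; adds one final fit of the winning configuration on the full training set. Every extra hyperparameter value multiplies the total, which is why grid search stops being practical for larger search spaces.&lt;/p&gt;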

&lt;h3&gt;
  
  
  SVM: A Stronger Classifier
&lt;/h3&gt;

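&lt;p&gt;Swapping Naive Bayes for a linear SVM (support vector machine), trained here with stochastic gradient descent via &lt;code&gt;SGDClassifier(loss='hinge')&lt;/code&gt;, beats the untuned Naive Bayes baseline (77.4%) by a wide margin before any grid search:&lt;br&gt;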
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.linear_model&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;SGDClassifier&lt;/span&gt;

&lt;span class="n"&gt;text_clf_svm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Pipeline&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;vect&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tfidf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;TfidfTransformer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;clf-svm&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;SGDClassifier&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;hinge&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;penalty&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;l2&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                               &lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_iter&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                               &lt;span class="n"&gt;random_state&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;text_clf_svm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;SVM accuracy: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;accuracy_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text_clf_svm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# SVM accuracy: 82.4%
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That's 82.4% with no grid search at all. Tuning the SVM with grid search yields 83.5%, virtually identical to the tuned Naive Bayes (83.6%).&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fopvvu7yy9dotcwi6cpj5.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fopvvu7yy9dotcwi6cpj5.webp" alt="Accuracy comparison: Naive Bayes baseline (77.4%), NB with stop words (81.7%), SVM baseline (82.4%), NB tuned (83.6%), SVM tuned (83.5%)." width="800" height="434"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The story is clear: the biggest gains come from better feature representation (bigrams, stop word removal, IDF weighting) rather than the choice of classifier. With good features, even the "naive" model performs competitively.&lt;/p&gt;

&lt;h3&gt;
  
  
  What the Model Actually Learns
&lt;/h3&gt;

&lt;p&gt;What words does the classifier rely on? Raw class-conditional probabilities are dominated by common words like "the" and "of". To find truly discriminative features, we compare each word's log-probability within a class against its average across all classes:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;sklearn.feature_extraction.text&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;TfidfVectorizer&lt;/span&gt;

&lt;span class="n"&gt;tfidf_vect&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;TfidfVectorizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stop_words&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;english&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_df&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;min_df&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;X_tfidf&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tfidf_vect&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit_transform&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;clf_disc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;MultinomialNB&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X_tfidf&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;feature_names&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tfidf_vect&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_feature_names_out&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="n"&gt;log_probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;clf_disc&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;feature_log_prob_&lt;/span&gt;
&lt;span class="n"&gt;mean_log_prob&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;log_probs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;axis&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;discriminativeness&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;log_probs&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;mean_log_prob&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;category&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target_names&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;top_indices&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;discriminativeness&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;argsort&lt;/span&gt;&lt;span class="p"&gt;()[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;:][::&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;category&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;, &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;feature_names&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;top_indices&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frw1ez38576psnc0rd6np.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Frw1ez38576psnc0rd6np.webp" alt="Most discriminative words for four categories: comp.graphics, rec.sport.baseball, sci.space, and talk.politics.mideast." width="800" height="599"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The model learns sensible patterns. &lt;code&gt;sci.space&lt;/code&gt; relies on words like "space", "orbit", and "nasa". &lt;code&gt;rec.sport.baseball&lt;/code&gt; relies on "baseball", "team", and "pitching". &lt;code&gt;talk.politics.mideast&lt;/code&gt; picks up "israel", "armenian", and "turkish". These are the words that carry the strongest evidence for each category, well beyond their background frequency.&lt;/p&gt;

&lt;h3&gt;
  
  
  Stemming: Reducing Words to Roots
&lt;/h3&gt;

&lt;p&gt;Stemming maps words to their root form ("running" to "run", "computers" to "comput"). This merges related word forms into a single feature, reducing vocabulary size:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;nltk&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;nltk.stem.snowball&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;SnowballStemmer&lt;/span&gt;

&lt;span class="n"&gt;nltk&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;download&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;stopwords&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;quiet&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# stop word corpus required by ignore_stopwords=True&lt;/span&gt;
&lt;span class="n"&gt;stemmer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;SnowballStemmer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;english&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ignore_stopwords&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;StemmedCountVectorizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;CountVectorizer&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;build_analyzer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;analyzer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;build_analyzer&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;doc&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;stemmer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stem&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;analyzer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;doc&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;

&lt;span class="n"&gt;text_clf_stemmed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Pipeline&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;vect&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;StemmedCountVectorizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stop_words&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;english&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;tfidf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;TfidfTransformer&lt;/span&gt;&lt;span class="p"&gt;()),&lt;/span&gt;
    &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;clf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;MultinomialNB&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;fit_prior&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;text_clf_stemmed&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;twenty_train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;NB + stemming + stop words: &lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;
      &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;accuracy_score&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text_clf_stemmed&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;twenty_test&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Stemming often gives a small additional boost, though the size of the effect depends on the dataset. The code above uses the Snowball stemmer (also known as Porter2), a refined version of Porter's classic 1980 algorithm that handles some irregular forms more gracefully.&lt;/p&gt;

&lt;h3&gt;
  
  
  When NOT to Use Bag-of-Words
&lt;/h3&gt;

&lt;p&gt;This approach has clear limitations:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Word order is lost.&lt;/strong&gt; "Dog bites man" and "man bites dog" produce the same unigram vector; bigrams recover only local order. For tasks where order matters (sentiment analysis, textual entailment), you need sequence models or contextual embeddings.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Synonyms are invisible.&lt;/strong&gt; If test documents use different words for the same concepts, they won't match. Pre-trained embeddings (Word2Vec, BERT) capture semantic similarity.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Short documents suffer.&lt;/strong&gt; With only a few words, the sparse vector is too noisy for reliable classification. Transformer models handle short texts much better.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Scalability ceiling.&lt;/strong&gt; As the number of overlapping categories grows, the independence assumption becomes more costly.&lt;/li&gt;
&lt;/ol&gt;
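&lt;p&gt;The first limitation is easy to verify directly. The sketch below uses only the standard library; &lt;code&gt;ngram_counts&lt;/code&gt; is a hypothetical helper with naive whitespace tokenisation, not scikit-learn's tokeniser:&lt;/p&gt;

```python
from collections import Counter

def ngram_counts(text, n):
    """Count word n-grams in a lowercased, whitespace-tokenised string."""
    tokens = text.lower().split()
    grams = [' '.join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    return Counter(grams)

a = 'Dog bites man'
b = 'Man bites dog'

# Unigram bag-of-words: identical counts, so the sentences are indistinguishable.
same_unigrams = ngram_counts(a, 1) == ngram_counts(b, 1)

# Bigrams keep local word order, so the two representations differ.
same_bigrams = ngram_counts(a, 2) == ngram_counts(b, 2)

print(same_unigrams, same_bigrams)  # True False
```

&lt;p&gt;Bigrams only look one word ahead, so longer-range structure (negation several words away, for instance) is still lost.&lt;/p&gt;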

&lt;p&gt;For many practical applications, TF-IDF with Naive Bayes remains hard to beat when you factor in the ratio of performance to complexity. It trains in seconds, requires no GPU, and produces interpretable results.&lt;/p&gt;

&lt;h2&gt;
  
  
  Where This Comes From
&lt;/h2&gt;

&lt;h3&gt;
  
  
  McCallum &amp;amp; Nigam (1998)
&lt;/h3&gt;

&lt;p&gt;The foundational paper for Naive Bayes text classification is &lt;strong&gt;McCallum, A. &amp;amp; Nigam, K. (1998)&lt;/strong&gt; &lt;a href="https://www.cs.cmu.edu/~knigam/papers/multinomial-aaaiws98.pdf" rel="noopener noreferrer"&gt;"A Comparison of Event Models for Naive Bayes Text Classification"&lt;/a&gt;, presented at the AAAI Workshop on Learning for Text Categorization.&lt;/p&gt;

&lt;p&gt;They compared two Naive Bayes variants for text:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Multi-variate Bernoulli&lt;/strong&gt;: each word is a binary feature (present or absent). This is &lt;code&gt;BernoulliNB&lt;/code&gt; in scikit-learn.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Multinomial&lt;/strong&gt;: each word is a count feature. This is the &lt;code&gt;MultinomialNB&lt;/code&gt; our pipeline uses.&lt;/li&gt;
&lt;/ul&gt;

&lt;blockquote&gt;
&lt;p&gt;"We find that the multinomial model is almost uniformly superior, especially for large vocabulary sizes."&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;The multinomial model works better because it uses word frequency information. A document mentioning "baseball" 15 times is stronger evidence for &lt;code&gt;rec.sport.baseball&lt;/code&gt; than one mentioning it once. The Bernoulli model discards this frequency signal entirely.&lt;/p&gt;
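&lt;p&gt;A small numeric sketch makes the difference concrete. The probabilities below are made-up illustrative values, not estimates from 20 Newsgroups, and the two functions are hypothetical helpers that compute how much evidence (log-odds) the word "baseball" contributes in favour of &lt;code&gt;rec.sport.baseball&lt;/code&gt; over &lt;code&gt;sci.space&lt;/code&gt; under each event model:&lt;/p&gt;

```python
import math

# Hypothetical probabilities for the word "baseball" (illustrative numbers only).
p_word = {'rec.sport.baseball': 0.02, 'sci.space': 0.001}    # multinomial: P(w | c) per token
p_present = {'rec.sport.baseball': 0.60, 'sci.space': 0.05}  # Bernoulli: P(w occurs in doc | c)

def multinomial_evidence(count):
    """Log-odds from `count` occurrences: each occurrence adds the same amount."""
    return count * (math.log(p_word['rec.sport.baseball']) - math.log(p_word['sci.space']))

def bernoulli_evidence(count):
    """Log-odds under the Bernoulli model: only presence or absence matters."""
    if count != 0:
        return math.log(p_present['rec.sport.baseball']) - math.log(p_present['sci.space'])
    return math.log(1 - p_present['rec.sport.baseball']) - math.log(1 - p_present['sci.space'])

for count in (1, 15):
    print(count, round(multinomial_evidence(count), 2), round(bernoulli_evidence(count), 2))
```

&lt;p&gt;Under the multinomial model the evidence grows linearly with the count, so 15 mentions weigh 15 times as much as one. Under the Bernoulli model the two documents contribute exactly the same evidence.&lt;/p&gt;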

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F26alprhgzok3pekyk8by.webp" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F26alprhgzok3pekyk8by.webp" alt="Comparison of the two Naive Bayes event models for text: the multivariate Bernoulli model uses binary word presence, while the multinomial model uses word counts, capturing frequency information." width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  The Multinomial Model
&lt;/h3&gt;

&lt;p&gt;Formally, the predicted class for a document &lt;code&gt;$d$&lt;/code&gt; is:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Chat%257Bc%257D%2520%253D%2520%255Carg%255Cmax_c%2520%255Cleft%255B%255Clog%2520P%28c%29%2520%252B%2520%255Csum_%257Bi%253D1%257D%255E%257B%257CV%257C%257D%2520n_i%28d%29%2520%255Clog%2520P%28w_i%2520%255Cmid%2520c%29%255Cright%255D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Chat%257Bc%257D%2520%253D%2520%255Carg%255Cmax_c%2520%255Cleft%255B%255Clog%2520P%28c%29%2520%252B%2520%255Csum_%257Bi%253D1%257D%255E%257B%257CV%257C%257D%2520n_i%28d%29%2520%255Clog%2520P%28w_i%2520%255Cmid%2520c%29%255Cright%255D" alt="equation" width="497" height="90"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Where:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;$P(c)$&lt;/code&gt; is the class prior (fraction of training documents in class &lt;code&gt;$c$&lt;/code&gt;)&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;$n_i(d)$&lt;/code&gt; is the count of word &lt;code&gt;$w_i$&lt;/code&gt; in document &lt;code&gt;$d$&lt;/code&gt; (in our pipeline, the TF-IDF weight takes the place of the raw count)
&lt;/li&gt;
&lt;li&gt;
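&lt;code&gt;$P(w_i \mid c)$&lt;/code&gt; is estimated with additive (Laplace/Lidstone) smoothing, where &lt;code&gt;$n_{ic}$&lt;/code&gt; is the total count of word &lt;code&gt;$w_i$&lt;/code&gt; across the training documents of class &lt;code&gt;$c$&lt;/code&gt; and &lt;code&gt;$|V|$&lt;/code&gt; is the vocabulary size:&lt;/li&gt;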
&lt;/ul&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28w_i%2520%255Cmid%2520c%29%2520%253D%2520%255Cfrac%257Bn_%257Bic%257D%2520%252B%2520%255Calpha%257D%257B%255Csum_%257Bj%253D1%257D%255E%257B%257CV%257C%257D%2520n_%257Bjc%257D%2520%252B%2520%255Calpha%2520%257CV%257C%257D" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257DP%28w_i%2520%255Cmid%2520c%29%2520%253D%2520%255Cfrac%257Bn_%257Bic%257D%2520%252B%2520%255Calpha%257D%257B%255Csum_%257Bj%253D1%257D%255E%257B%257CV%257C%257D%2520n_%257Bjc%257D%2520%252B%2520%255Calpha%2520%257CV%257C%257D" alt="equation" width="301" height="65"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The smoothing parameter &lt;code&gt;$\alpha$&lt;/code&gt; prevents zero probabilities for words that never appeared in a particular class during training. Our grid search preferred &lt;code&gt;$\alpha = 0.001$&lt;/code&gt; over &lt;code&gt;$\alpha = 0.01$&lt;/code&gt;, the only two values it tried, meaning the model trusts the training data more and smooths far less aggressively than the default &lt;code&gt;$\alpha = 1.0$&lt;/code&gt;.&lt;/p&gt;
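&lt;p&gt;The two formulas above fit in a few lines of plain Python. This is a from-scratch sketch on a made-up four-document corpus, not scikit-learn's implementation; the names &lt;code&gt;fit&lt;/code&gt; and &lt;code&gt;predict&lt;/code&gt; and the toy data are assumptions for illustration:&lt;/p&gt;

```python
import math
from collections import Counter

# Toy training set (made-up documents, for illustration only).
train = [
    ('space', 'orbit nasa launch orbit'),
    ('space', 'nasa shuttle launch'),
    ('baseball', 'team pitching baseball game'),
    ('baseball', 'baseball team win'),
]

vocab = sorted({w for _, doc in train for w in doc.split()})
classes = sorted({c for c, _ in train})

def fit(alpha):
    """Return log priors and smoothed log P(w | c) for every class and word."""
    log_prior, log_prob = {}, {}
    for c in classes:
        docs = [doc for label, doc in train if label == c]
        log_prior[c] = math.log(len(docs) / len(train))           # P(c)
        counts = Counter(w for doc in docs for w in doc.split())  # n_ic
        total = sum(counts.values())                              # sum over j of n_jc
        log_prob[c] = {w: math.log((counts[w] + alpha) / (total + alpha * len(vocab)))
                       for w in vocab}
    return log_prior, log_prob

def predict(doc, alpha=1.0):
    """Arg max over classes of log P(c) + sum_i n_i(d) * log P(w_i | c)."""
    log_prior, log_prob = fit(alpha)
    n = Counter(w for w in doc.split() if w in vocab)  # out-of-vocabulary words are ignored
    scores = {c: log_prior[c] + sum(k * log_prob[c][w] for w, k in n.items())
              for c in classes}
    return max(scores, key=scores.get), scores

label, scores = predict('nasa launch team')
print(label)  # space
```

&lt;p&gt;Calling &lt;code&gt;predict('nasa launch team', alpha=0.001)&lt;/code&gt; returns the same label, but the gap between the class scores widens because a word unseen in a class is penalised far more heavily when &lt;code&gt;alpha&lt;/code&gt; is small.&lt;/p&gt;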

&lt;h3&gt;
  
  
  TF-IDF: Salton &amp;amp; Buckley (1988)
&lt;/h3&gt;

&lt;p&gt;TF-IDF was formalised by &lt;strong&gt;Salton, G. &amp;amp; Buckley, C. (1988)&lt;/strong&gt; &lt;a href="https://doi.org/10.1016/0306-4573(88)90021-0" rel="noopener noreferrer"&gt;"Term-weighting approaches in automatic text retrieval"&lt;/a&gt;, &lt;em&gt;Information Processing &amp;amp; Management&lt;/em&gt;. The core idea predates this work: Sparck Jones proposed inverse document frequency in 1972.&lt;/p&gt;

&lt;p&gt;Scikit-learn's variant uses:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BIDF%257D%28t%29%2520%253D%2520%255Clog%255C%21%255Cfrac%257B1%2520%252B%2520N%257D%257B1%2520%252B%2520n_t%257D%2520%252B%25201" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Flatex.codecogs.com%2Fpng.image%3F%255Cdpi%257B150%257D%255Ctext%257BIDF%257D%28t%29%2520%253D%2520%255Clog%255C%21%255Cfrac%257B1%2520%252B%2520N%257D%257B1%2520%252B%2520n_t%257D%2520%252B%25201" alt="equation" width="246" height="55"&gt;&lt;/a&gt;&lt;/p&gt;

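&lt;p&gt;Here &lt;code&gt;$N$&lt;/code&gt; is the number of documents in the corpus and &lt;code&gt;$n_t$&lt;/code&gt; is the number of documents that contain term &lt;code&gt;$t$&lt;/code&gt;. The "+1" terms inside the fraction prevent division by zero, and the "+1" added to the logarithm ensures no word gets zero weight, even one that appears in every document. After computing TF-IDF, each document vector is L2-normalised to unit length.&lt;/p&gt;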
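&lt;p&gt;The weighting can be reproduced by hand on a tiny corpus. The sketch below assumes the smoothed variant shown above (scikit-learn's default &lt;code&gt;smooth_idf=True&lt;/code&gt;), the natural logarithm, and raw counts as term frequency; &lt;code&gt;idf&lt;/code&gt;, &lt;code&gt;tfidf_vector&lt;/code&gt;, and the three documents are made up for illustration:&lt;/p&gt;

```python
import math

# Toy corpus (made-up documents, for illustration only).
docs = [
    'the rocket reached orbit',
    'the team won the game',
    'the orbit of the moon',
]
tokenised = [d.split() for d in docs]
N = len(tokenised)  # total number of documents

def idf(term):
    """Smoothed IDF: ln((1 + N) / (1 + n_t)) + 1, n_t = documents containing the term."""
    n_t = sum(term in tokens for tokens in tokenised)
    return math.log((1 + N) / (1 + n_t)) + 1

def tfidf_vector(tokens):
    """Raw term counts times IDF, then L2-normalised to unit length."""
    weights = {t: tokens.count(t) * idf(t) for t in set(tokens)}
    norm = math.sqrt(sum(w * w for w in weights.values()))
    return {t: w / norm for t, w in weights.items()}

print(idf('the'), idf('orbit'), idf('rocket'))
vec = tfidf_vector(tokenised[0])
print(sorted(vec, key=vec.get, reverse=True))
```

&lt;p&gt;The word "the" appears in every document, so its IDF is exactly 1.0, the floor set by the outer "+1". Rarer words such as "rocket" receive a larger weight and dominate the normalised vector.&lt;/p&gt;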

&lt;h3&gt;
  
  
  Historical Context
&lt;/h3&gt;

&lt;p&gt;Text classification has a long lineage:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Maron (1961)&lt;/strong&gt; — First automatic text classification using probabilistic indexing&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Salton (1971)&lt;/strong&gt; — The SMART retrieval system, introducing many weighting schemes&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Sparck Jones (1972)&lt;/strong&gt; — Inverse document frequency&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Lewis (1998)&lt;/strong&gt; — The Reuters benchmark that standardised evaluation&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Joachims (1998)&lt;/strong&gt; — Showed SVMs outperform NB on text (our results confirm this: 82.4% vs 77.4%)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;McCallum &amp;amp; Nigam (1998)&lt;/strong&gt; — Systematic comparison of NB event models&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Today, transformer-based models (BERT, GPT) dominate text classification benchmarks, but TF-IDF with Naive Bayes remains a standard baseline for its speed, interpretability, and surprising competitiveness.&lt;/p&gt;

&lt;h3&gt;
  
  
  Further Reading
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://www.cs.cmu.edu/~knigam/papers/multinomial-aaaiws98.pdf" rel="noopener noreferrer"&gt;McCallum &amp;amp; Nigam (1998)&lt;/a&gt; — Multinomial vs Bernoulli NB for text&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://doi.org/10.1016/0306-4573(88)90021-0" rel="noopener noreferrer"&gt;Salton &amp;amp; Buckley (1988)&lt;/a&gt; — Systematic evaluation of TF-IDF variants&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://link.springer.com/chapter/10.1007/BFb0026683" rel="noopener noreferrer"&gt;Joachims (1998)&lt;/a&gt; — Text categorisation with SVMs&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Manning, Raghavan &amp;amp; Schütze (2008)&lt;/strong&gt; &lt;a href="https://nlp.stanford.edu/IR-book/" rel="noopener noreferrer"&gt;&lt;em&gt;Introduction to Information Retrieval&lt;/em&gt;&lt;/a&gt; — Free textbook covering TF-IDF, NB, and SVM for text in depth&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Try It Yourself
&lt;/h2&gt;

&lt;p&gt;The &lt;a href="https://colab.research.google.com/github/zhubarb/sesen_ai_ml_tutorials/blob/main/notebooks/nlp/tfidf_naive_bayes.ipynb" rel="noopener noreferrer"&gt;interactive notebook&lt;/a&gt; includes exercises:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Subset classification&lt;/strong&gt; — Use only 4 categories (&lt;code&gt;comp.graphics&lt;/code&gt;, &lt;code&gt;rec.sport.baseball&lt;/code&gt;, &lt;code&gt;sci.space&lt;/code&gt;, &lt;code&gt;talk.politics.mideast&lt;/code&gt;). How much does accuracy improve with fewer, more distinct categories?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Feature engineering&lt;/strong&gt; — Add &lt;code&gt;min_df=5&lt;/code&gt; and &lt;code&gt;max_df=0.5&lt;/code&gt; to &lt;code&gt;CountVectorizer&lt;/code&gt; to trim rare and ubiquitous words. How does this affect accuracy and vocabulary size?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Bernoulli vs Multinomial&lt;/strong&gt; — Replace &lt;code&gt;MultinomialNB&lt;/code&gt; with &lt;code&gt;BernoulliNB&lt;/code&gt;. Does the McCallum &amp;amp; Nigam finding hold on this dataset?&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Beyond bag-of-words&lt;/strong&gt; — Use &lt;code&gt;TfidfVectorizer&lt;/code&gt; with &lt;code&gt;sublinear_tf=True&lt;/code&gt; and character n-grams (&lt;code&gt;analyzer='char_wb'&lt;/code&gt;, &lt;code&gt;ngram_range=(3,5)&lt;/code&gt;). Character n-grams capture morphological patterns that word-level features miss.&lt;/li&gt;
&lt;/ol&gt;
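
&lt;p&gt;For exercise 2, it can help to see what the two thresholds do before running the notebook. The sketch below is a plain-Python approximation, not scikit-learn's implementation. It assumes whitespace tokenisation, treats an integer &lt;code&gt;min_df&lt;/code&gt; as a document count and a float &lt;code&gt;max_df&lt;/code&gt; as a fraction of documents; the function name &lt;code&gt;prune_vocabulary&lt;/code&gt; and the toy corpus are made up for illustration.&lt;/p&gt;

```python
from collections import Counter


def prune_vocabulary(docs, min_df=1, max_df=1.0):
    """Keep terms whose document frequency lies within [min_df, max_df * n_docs].

    Illustrative only: min_df is an absolute document count and max_df a
    proportion of documents, mirroring how the exercise uses them.
    """
    doc_freq = Counter()
    for doc in docs:
        # Count each term once per document (document frequency, not term frequency).
        doc_freq.update(set(doc.lower().split()))
    max_count = max_df * len(docs)
    return sorted(term for term, df in doc_freq.items()
                  if max_count >= df >= min_df)


# Hypothetical toy corpus, not the 20 Newsgroups data used in the post.
docs = [
    "the shuttle reached orbit",
    "the pitcher threw the ball",
    "the orbit of the moon",
    "the ball game ended",
]

full_vocab = prune_vocabulary(docs)
pruned_vocab = prune_vocabulary(docs, min_df=2, max_df=0.5)
print(len(full_vocab), pruned_vocab)
```

&lt;p&gt;On this toy corpus the ubiquitous word "the" and every word that occurs in only one document are dropped, which is the same trade-off the exercise asks you to measure on the real data.&lt;/p&gt;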

&lt;h2&gt;
  
  
  Interactive Tools
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/classification-metrics-calculator" rel="noopener noreferrer"&gt;Classification Metrics Calculator&lt;/a&gt; — Compute precision, recall, F1, and other metrics from your own confusion matrix&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/bayes-theorem-calculator" rel="noopener noreferrer"&gt;Bayes' Theorem Calculator&lt;/a&gt; — Explore the Bayesian reasoning that underpins Naive Bayes classification&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Related Posts
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/maximum-likelihood-estimation-from-scratch" rel="noopener noreferrer"&gt;Maximum Likelihood Estimation from Scratch&lt;/a&gt; — The estimation method behind Naive Bayes parameter learning&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/from-mle-to-bayesian-inference" rel="noopener noreferrer"&gt;From MLE to Bayesian Inference&lt;/a&gt; — The Bayes' theorem foundation that powers Naive Bayes classification&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://sesen.ai/blog/hyperparameter-optimization-grid-random-bayesian" rel="noopener noreferrer"&gt;Hyperparameter Optimization: Grid vs Random vs Bayesian&lt;/a&gt; — A deeper look at grid search and smarter alternatives&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  Frequently Asked Questions
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Why is Naive Bayes called "naive"?
&lt;/h3&gt;

&lt;p&gt;The "naive" refers to the conditional independence assumption: the model assumes that each word in a document is independent of every other word, given the class. This is clearly wrong (e.g. "neural" and "network" tend to co-occur), but it works surprisingly well in practice because classification only requires getting the ranking of class probabilities right, not the exact values. Independence errors tend to cancel out across thousands of features.&lt;/p&gt;
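
&lt;p&gt;A minimal sketch of what the independence assumption means for the decision rule: the class score is the log prior plus one independent log-likelihood term per word. The probabilities below are hand-picked, hypothetical numbers rather than fitted values, and &lt;code&gt;nb_log_score&lt;/code&gt; is an illustrative helper, not a library function.&lt;/p&gt;

```python
import math


def nb_log_score(tokens, log_prior, word_probs):
    """Unnormalised log-posterior under the naive independence assumption.

    Each token contributes its own log P(word | class); no interaction
    between words is modelled.
    """
    return log_prior + sum(math.log(word_probs[t]) for t in tokens)


# Hypothetical, hand-picked probabilities for two classes (not fitted values).
priors = {"sci.space": 0.5, "rec.sport.baseball": 0.5}
word_probs = {
    "sci.space": {"orbit": 0.05, "launch": 0.04, "game": 0.001},
    "rec.sport.baseball": {"orbit": 0.001, "launch": 0.002, "game": 0.06},
}

doc = ["orbit", "launch", "game"]
scores = {
    label: nb_log_score(doc, math.log(priors[label]), word_probs[label])
    for label in priors
}
# Only the ranking of the scores matters for the predicted label.
predicted = max(scores, key=scores.get)
print(predicted)
```

&lt;p&gt;The prediction uses only the comparison between the two scores, which is why poorly calibrated probabilities need not change the predicted label.&lt;/p&gt;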

&lt;h3&gt;
  
  
  What is the difference between TF-IDF and raw word counts?
&lt;/h3&gt;

&lt;p&gt;Raw word counts treat all words equally, so common words like "the" and "is" dominate the representation despite carrying little discriminative information. TF-IDF re-weights each word's count by how rare the word is across the entire corpus. Words that appear in many documents get downweighted, while words distinctive to a few documents get amplified. This makes the representation much more informative for classification.&lt;/p&gt;
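
&lt;p&gt;The re-weighting can be sketched in a few lines of plain Python. This version assumes the textbook form, weight = tf * log(N / df), with whitespace tokenisation. scikit-learn's &lt;code&gt;TfidfVectorizer&lt;/code&gt; uses a smoothed IDF and normalises each document vector by default, so its numbers will differ. The &lt;code&gt;tfidf&lt;/code&gt; helper and the toy corpus are illustrative only.&lt;/p&gt;

```python
import math
from collections import Counter


def tfidf(docs):
    """Return one {term: weight} dict per document using tf * log(N / df)."""
    tokenised = [doc.lower().split() for doc in docs]
    n_docs = len(tokenised)
    doc_freq = Counter()
    for tokens in tokenised:
        # Document frequency: in how many documents does each term occur?
        doc_freq.update(set(tokens))
    weights = []
    for tokens in tokenised:
        counts = Counter(tokens)
        weights.append({
            term: count * math.log(n_docs / doc_freq[term])
            for term, count in counts.items()
        })
    return weights


# Hypothetical toy corpus for illustration.
docs = [
    "the rocket reached the orbit",
    "the pitcher threw the ball",
    "the senate passed the bill",
]
weights = tfidf(docs)
print(weights[0])
```

&lt;p&gt;In the first document "the" has the highest raw count but a weight of zero, because it appears in every document, while "rocket" keeps a positive weight.&lt;/p&gt;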

&lt;h3&gt;
  
  
  When should I use Naive Bayes instead of a transformer model like BERT?
&lt;/h3&gt;

&lt;p&gt;Naive Bayes with TF-IDF is an excellent choice when you need fast training (seconds, not hours) or interpretability (you can inspect which words drive predictions), or when labelled data is limited. It also requires no GPU. For tasks where word order matters (sentiment analysis, entailment) or where you need state-of-the-art accuracy on competitive benchmarks, transformer models typically outperform it by a clear margin.&lt;/p&gt;

&lt;h3&gt;
  
  
  What does the smoothing parameter alpha do in MultinomialNB?
&lt;/h3&gt;

&lt;p&gt;Alpha controls additive smoothing, which prevents zero probabilities for words that never appeared in a particular class during training. With alpha = 1.0 (the default), this is Laplace smoothing: the model adds a pseudocount of 1 to every word-class combination. Smaller values like 0.001 (often called Lidstone smoothing) trust the training data more and smooth less aggressively. The optimal value depends on your dataset and can be found through cross-validation.&lt;/p&gt;
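
&lt;p&gt;As a rough illustration, the smoothed estimate for a word in a class is (count + alpha) / (total + alpha * |V|), where |V| is the vocabulary size. The sketch below assumes every training token is in the vocabulary; &lt;code&gt;smoothed_word_probs&lt;/code&gt; and the token lists are hypothetical.&lt;/p&gt;

```python
from collections import Counter


def smoothed_word_probs(class_tokens, vocabulary, alpha=1.0):
    """Estimate P(word | class) as (count + alpha) / (total + alpha * |V|)."""
    counts = Counter(class_tokens)
    denom = len(class_tokens) + alpha * len(vocabulary)
    return {word: (counts[word] + alpha) / denom for word in vocabulary}


# Hypothetical training tokens for a single class.
vocabulary = ["orbit", "launch", "game", "ball"]
space_tokens = ["orbit", "orbit", "launch", "orbit"]

for alpha in (1.0, 0.001):
    probs = smoothed_word_probs(space_tokens, vocabulary, alpha)
    # "game" never occurs in this class, yet its probability stays above zero.
    print(alpha, round(probs["orbit"], 4), round(probs["game"], 6))
```

&lt;p&gt;With alpha = 1.0 the unseen word "game" receives a probability of 0.125. With alpha = 0.001 it is still non-zero but much smaller, so the estimate stays closer to the observed counts.&lt;/p&gt;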

&lt;h3&gt;
  
  
  Why does the model confuse related categories like PC hardware and Mac hardware?
&lt;/h3&gt;

&lt;p&gt;The bag-of-words representation captures which words appear in a document but not the subtle semantic differences between closely related topics. Categories like PC hardware and Mac hardware share a large portion of their vocabulary (words like "drive", "memory", "board", "system"). The model can only distinguish them by the few words unique to each category, which may not always be present in a given document.&lt;/p&gt;

&lt;h3&gt;
  
  
  Can TF-IDF handle languages other than English?
&lt;/h3&gt;

&lt;p&gt;Yes. TF-IDF is language-agnostic at its core since it operates on tokens, not linguistic structures. However, you may need to adjust tokenisation for languages without clear word boundaries (e.g. Chinese or Japanese) and consider language-specific stop word lists. Stemming and lemmatisation tools are also language-dependent, so you would need appropriate resources for your target language.&lt;/p&gt;

</description>
      <category>supervisedlearning</category>
      <category>discriminative</category>
      <category>probabilistic</category>
    </item>
  </channel>
</rss>
