
Berkan Sesen

Posted on • Originally published at sesen.ai

Bayesian Survival Analysis with PyMC: Modelling Customer Churn

Every subscription business lives or dies by churn. Whether it is a B2B SaaS platform tracking annual contracts or a consumer app watching monthly renewals, the question is the same: how long will this customer stay? The data seems straightforward. Some subscribers cancelled after a month, others after a year. But a large share of customers are still active. They have not churned yet, and you do not know when, or whether, they will.

A colleague suggested dropping them from the analysis. That felt wrong, and it is: ignoring active customers biases your model toward shorter lifetimes, because you only learn from the people who already left.

The problem has a name: right-censoring. An active customer who signed up 8 months ago tells you something valuable: they survived at least 8 months. You don't know when (or whether) they'll churn, but that lower bound is real information.

Survival analysis handles censoring properly. In our previous post, we built hierarchical models in PyMC for grouped regression. This post extends that toolkit with a new ingredient: the ability to learn from incomplete observations.

By the end, you'll build a Bayesian accelerated failure time (AFT) model in PyMC, handle right-censored data with pm.Potential, compare Weibull and Log-Logistic distributions, and plot individual survival curves for different customer profiles.

Let's Build It

First, let's see the model in action. Click the badge below to open the full interactive notebook:

Open In Colab

We'll generate synthetic churn data for 1,000 customers, fit a Weibull AFT model, and plot survival curves.

Survival curves for three customer profiles building up as MCMC samples accumulate. Early frames show scattered, uncertain curves; later frames converge to smooth, separated survival functions for high-value, average, and at-risk customers.

import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import arviz as az
import matplotlib.pyplot as plt

np.random.seed(42)

# Generate synthetic churn data: 1,000 customers observed over 24 months
N = 1000
monthly_spend = np.random.normal(100, 30, N).clip(20, 250)
support_tickets = np.random.poisson(3, N).astype(float)

# Standardise covariates
spend_std = (monthly_spend - 100) / 30
tickets_std = (support_tickets - 3) / 2

# True AFT parameters (Gumbel / log-Weibull parameterisation)
true_alpha = np.array([2.5, 0.4, -0.3])  # intercept, spend, tickets
true_s = 0.6

# True log-time: Y = eta + s * W, where W ~ Gumbel(0,1)
eta_true = true_alpha[0] + true_alpha[1] * spend_std + true_alpha[2] * tickets_std
log_time_true = eta_true + true_s * np.random.gumbel(0, 1, N)
time_true = np.exp(log_time_true)

# Administrative censoring at 24 months
observation_window = 24.0
observed_time = np.minimum(time_true, observation_window)
censored = time_true > observation_window  # True = still active
log_observed_time = np.log(observed_time)

print(f"Total customers: {N}")
print(f"Churned: {(~censored).sum()} ({(~censored).mean():.0%})")
print(f"Still active (censored): {censored.sum()} ({censored.mean():.0%})")
Total customers: 1000
Churned: 664 (66%)
Still active (censored): 336 (34%)

Before fitting the Bayesian model, let's look at the empirical survival curve using the Kaplan-Meier estimator. This non-parametric method handles censoring correctly by shrinking the risk set at each event time: $\hat{S}(t) = \prod_{t_i \le t} \left(1 - \frac{d_i}{n_i}\right)$, where $n_i$ is the number of customers still at risk just before event time $t_i$ and $d_i$ the number who churn at $t_i$:

# Kaplan-Meier estimator (manual, no extra dependencies)
# Note: this simple loop assumes continuous times with no tied events
order = np.argsort(observed_time)
times_sorted = observed_time[order]
events_sorted = (~censored)[order].astype(int)

km_times = [0.0]
km_survival = [1.0]
n_at_risk = N

for t, event in zip(times_sorted, events_sorted):
    if event:
        km_survival.append(km_survival[-1] * (1 - 1 / n_at_risk))
        km_times.append(t)
    n_at_risk -= 1

fig, ax = plt.subplots(figsize=(8, 4))
ax.step(km_times, km_survival, where='post', color='#2196F3', lw=2)
ax.set_xlabel('Months since signup')
ax.set_ylabel('Survival probability')
ax.set_title('Kaplan-Meier Survival Curve')
ax.set_xlim(0, 25)
ax.set_ylim(0, 1.05)
plt.tight_layout()

Kaplan-Meier survival curve for the synthetic churn data. The curve drops steadily over 24 months with censoring tick marks visible. About 34% of customers survive past the 24-month observation window.

Now let's fit the Weibull AFT model. The key insight: if $T \sim \text{Weibull}$, then $Y = \log T$ follows a Gumbel distribution. So we model log-time with a Gumbel likelihood, which lets us write the linear predictor naturally:

def gumbel_log_sf(y, mu, sigma):
    """Log survival function of the Gumbel distribution."""
    return pt.log1p(-pt.exp(-pt.exp(-(y - mu) / sigma)))

with pm.Model() as weibull_aft:
    # Location coefficients (priors match original code: Normal(0, 2))
    alpha = pm.Normal('alpha', mu=0, sigma=2, shape=3)

    # Scale parameter (must be positive)
    log_s = pm.Normal('log_s', mu=0, sigma=1)
    s = pm.Deterministic('s', pm.math.exp(log_s))

    # Linear predictor for log-time
    eta = alpha[0] + alpha[1] * spend_std + alpha[2] * tickets_std

    # Uncensored customers: standard Gumbel likelihood
    y_obs = pm.Gumbel('y_obs', mu=eta[~censored], beta=s,
                       observed=log_observed_time[~censored])

    # Censored customers: survival function via pm.Potential
    y_cens = pm.Potential('y_cens',
        gumbel_log_sf(log_observed_time[censored], eta[censored], s))

    # Sample the posterior
    trace = pm.sample(1000, tune=2000, cores=4, chains=4,
                      random_seed=42, target_accept=0.9)

print(az.summary(trace, var_names=['alpha', 's']))

You just fit a Bayesian survival model that properly handles censored customers. The alpha coefficients tell you how each covariate affects time-to-churn: positive means longer survival, negative means faster churn. And unlike a point estimate from maximum likelihood, you get full posterior distributions over every parameter.

What Just Happened?

Right-Censoring: Learning from Incomplete Data

The 336 active customers in our data didn't churn during the 24-month observation window. For each one, we know they survived at least 24 months, but not how much longer they'll stay. This is right-censoring: the true event time is somewhere to the right of what we observed.

Timeline diagram showing 8 example customers. Five lines end with a red X (churn event) at various times. Three lines extend to the 24-month boundary and end with a green arrow (still active, censored). The observation window is shaded.

Standard regression would force you to either drop censored customers (biasing estimates downward) or code them as churning at 24 months (also biased). Survival analysis treats the two types of observation differently in the likelihood.
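A quick simulation makes the bias concrete. This sketch drops covariates and reuses the intercept (2.5) and scale (0.6) from the data-generation step above:

```python
import numpy as np

rng = np.random.default_rng(2)

# Log-lifetimes from the same Gumbel AFT form as above: intercept 2.5, scale 0.6
log_T = 2.5 + 0.6 * rng.gumbel(0, 1, 100_000)
T = np.exp(log_T)
censored = T > 24.0  # still active at the 24-month cutoff

print(log_T.mean())                        # true mean log-lifetime, ~2.85
print(log_T[~censored].mean())             # drop active customers: biased low
print(np.log(np.minimum(T, 24.0)).mean())  # code them as churned at 24: also low
```

Both naive estimates land well below the true mean log-lifetime; the survival-analysis likelihood avoids this by scoring active customers with $S(c_i)$ instead of pretending we observed their exit.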

For a churned customer at time $t_i$, the likelihood contribution is the probability density $f(t_i)$: we observed this exact event time. For a censored customer observed until time $c_i$, the contribution is the survival probability $S(c_i) = P(T > c_i)$: all we know is they lasted at least this long.

The total log-likelihood combines both pieces:

$$\log \mathcal{L} \;=\; \sum_{i:\,\text{churned}} \log f(t_i) \;+\; \sum_{i:\,\text{censored}} \log S(c_i)$$

This is exactly how our PyMC model works. The pm.Gumbel line handles the first sum (uncensored density). The pm.Potential line handles the second sum (censored survival).

Why Gumbel? The Weibull-Gumbel Connection

The Weibull distribution is the workhorse of survival analysis because it models flexible hazard rates: increasing, decreasing, or constant over time. But working with the Weibull directly is numerically awkward for regression.

Here's the trick. If $T \sim \text{Weibull}(k, \lambda)$, then $Y = \log T$ follows a Gumbel distribution:

$$Y = \log T \;\sim\; \text{Gumbel}(\mu, \sigma)$$

where $\mu = \log \lambda$ is the location and $\sigma = 1/k$ is the scale. This is the accelerated failure time (AFT) formulation: covariates shift $\mu$, effectively accelerating or decelerating time. We write the linear predictor as:

$$\mu_i \;=\; \alpha_0 + \alpha_1\,\text{spend}_i + \alpha_2\,\text{tickets}_i$$

A positive $\alpha_1$ means higher spending shifts log-time to the right (longer survival). A negative $\alpha_2$ means more support tickets shift it left (faster churn). The coefficients have a direct interpretation: a one-unit increase in $x_j$ multiplies the median survival time by $\exp(\alpha_j)$.
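A two-line check of that interpretation, using the simulated coefficient $\alpha_1 = 0.4$ from the data-generation step:

```python
import numpy as np

rng = np.random.default_rng(1)
alpha1, s = 0.4, 0.6
w = rng.gumbel(0, 1, 500_000)

# Lifetimes at x = 0 vs x = 1: the AFT claim is a median ratio of exp(alpha1)
t_x0 = np.exp(0.0 + s * w)
t_x1 = np.exp(alpha1 + s * w)
ratio = np.median(t_x1) / np.median(t_x0)
print(ratio)           # exp(0.4) ≈ 1.49
print(np.exp(alpha1))
```

Because the covariate shifts log-time additively, every quantile of the lifetime distribution, not just the median, scales by the same factor.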

pm.Potential: Telling PyMC About Partial Information

In our hierarchical regression post, every observation contributed a full likelihood term through pm.Normal(..., observed=y). Censored observations are different: they don't have a fully observed outcome. They only contribute through the survival function.

pm.Potential('name', value) adds value directly to the model's log-posterior. For censored data, we pass the log-survival probability:

y_cens = pm.Potential('y_cens',
    gumbel_log_sf(log_observed_time[censored], eta[censored], s))

Think of it this way. For a churned customer, we say "we observed them leave at time $t$" (standard likelihood). For an active customer, we say "all we know is they're still here after $c$ months" (survival function).
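To see exactly what the `pm.Potential` term adds, here is the same log-survival computed in plain NumPy for one hypothetical censored customer ($\mu = 3.0$ and $\sigma = 0.6$ are illustrative values, not fitted ones). For simple right-censoring, recent PyMC versions also offer `pm.Censored` as a built-in alternative to this pattern.

```python
import numpy as np

def gumbel_log_sf_np(y, mu, sigma):
    """NumPy mirror of gumbel_log_sf: log P(Y > y) for a Gumbel(mu, sigma)."""
    return np.log1p(-np.exp(-np.exp(-(y - mu) / sigma)))

# Log-posterior contribution of a customer still active at 24 months:
contribution = gumbel_log_sf_np(np.log(24.0), mu=3.0, sigma=0.6)
print(contribution)  # ≈ -0.645, i.e. P(survive past 24 months) ≈ 0.52
```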

MCMC Diagnostics

Before trusting the results, verify the sampler converged:

az.plot_trace(trace, var_names=['alpha', 's'])
plt.tight_layout()

ArviZ trace plots for the Weibull AFT model. Top row: alpha posteriors (intercept near 2.5, spend coefficient near 0.4, tickets coefficient near −0.3) with MCMC traces. Bottom row: scale parameter s centred near 0.63. All four chains mix well with stable, overlapping traces.

Check the same three diagnostics we covered in the hierarchical regression post: chains should look like "hairy caterpillars" (good mixing), R-hat below 1.01 (convergence), and effective sample size of at least 400 in total, roughly 100 per chain (low autocorrelation).
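R-hat itself is easy to compute by hand, which demystifies what ArviZ reports. A minimal split-R-hat sketch (following the Vehtari et al. formulation, without rank-normalisation) applied to fake chains:

```python
import numpy as np

def split_rhat(chains):
    """Split-R-hat for one parameter; chains has shape (n_chains, n_draws)."""
    n = chains.shape[1] // 2
    halves = np.concatenate([chains[:, :n], chains[:, n:2 * n]], axis=0)
    within = halves.var(axis=1, ddof=1).mean()       # W: mean within-chain variance
    between = n * halves.mean(axis=1).var(ddof=1)    # B: variance of chain means
    var_plus = (n - 1) / n * within + between / n    # pooled variance estimate
    return float(np.sqrt(var_plus / within))

rng = np.random.default_rng(0)
good = rng.normal(0.6, 0.05, size=(4, 1000))         # well-mixed chains
stuck = good + 0.3 * np.arange(4)[:, None]           # chains at different levels

print(split_rhat(good))   # ~1.00: safe
print(split_rhat(stuck))  # well above 1.01: chains disagree
```

ArviZ's `az.rhat` reports essentially this quantity (with rank-normalisation on top); values creeping above 1.01 mean the chains disagree about where the posterior mass sits.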

Survival Curves from the Posterior

The payoff of a Bayesian AFT model is individual survival curves with uncertainty bands. For any customer profile, we compute the survival probability at each time point across all posterior samples:

t_grid = np.linspace(0.5, 36, 200)
log_t_grid = np.log(t_grid)

# Extract posterior samples
alpha_post = trace.posterior['alpha'].values.reshape(-1, 3)
s_post = trace.posterior['s'].values.flatten()

profiles = {
    'High-value (spend +1.5σ, tickets −1σ)': (1.5, -1.0, '#2196F3'),
    'Average customer':                        (0.0,  0.0, '#FF9800'),
    'At-risk (spend −1.5σ, tickets +2σ)':     (-1.5,  2.0, '#F44336'),
}

fig, ax = plt.subplots(figsize=(8, 5))
for label, (sp, tk, color) in profiles.items():
    eta_post = alpha_post[:, 0] + alpha_post[:, 1] * sp + alpha_post[:, 2] * tk
    survival = np.zeros((len(eta_post), len(t_grid)))
    for i in range(len(eta_post)):
        z = (log_t_grid - eta_post[i]) / s_post[i]
        survival[i] = 1 - np.exp(-np.exp(-z))
    mean_surv = survival.mean(axis=0)
    lower = np.percentile(survival, 3, axis=0)
    upper = np.percentile(survival, 97, axis=0)
    ax.plot(t_grid, mean_surv, color=color, lw=2, label=label)
    ax.fill_between(t_grid, lower, upper, color=color, alpha=0.15)

ax.set_xlabel('Months since signup')
ax.set_ylabel('Survival probability')
ax.set_title('Predicted Survival Curves by Customer Profile')
ax.legend(loc='upper right', fontsize=9)
ax.set_xlim(0, 36)
ax.set_ylim(0, 1.05)
plt.tight_layout()

Survival curves for three customer profiles with 94% credible-interval bands. The high-value customer (blue) stays above 85% survival at 24 months. The average customer (orange) crosses 50% around 13 months. The at-risk customer (red) drops below 20% by month 10.

Each curve shows the model's predicted probability that a customer with those characteristics survives beyond a given time. The high-value customer has a much flatter curve: their predicted median lifetime exceeds 36 months. The at-risk customer (low spend, many support tickets) has a steep drop-off with a median around 5 months.

Notice the uncertainty bands widen at longer times, especially for the at-risk profile. Fewer customers with those characteristics survive that long, so the model has less data to constrain the prediction.
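The medians quoted above can be reproduced from point estimates. For the (max-form) Gumbel used here, $\text{median}(Y) = \mu - \sigma \ln(\ln 2)$, so $\text{median}(T) = \exp(\mu - \sigma \ln \ln 2)$. Plugging in the true data-generating parameters (the posterior means land close to these):

```python
import numpy as np

a0, a1, a2, s = 2.5, 0.4, -0.3, 0.6   # true values from the data-generation step

def median_lifetime(spend_z, tickets_z):
    mu = a0 + a1 * spend_z + a2 * tickets_z
    return np.exp(mu - s * np.log(np.log(2)))  # -ln(ln 2) ≈ +0.367

print(median_lifetime(1.5, -1.0))   # high-value: ~37 months
print(median_lifetime(0.0, 0.0))    # average:    ~15 months
print(median_lifetime(-1.5, 2.0))   # at-risk:    ~4.6 months
```

Small differences from the plotted curves are expected: those are computed from the full posterior rather than a single point estimate.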

Going Deeper

Covariates in the Scale Too

The model above uses a constant scale parameter $s$ for all customers. The original code I adapted goes further by making the scale covariate-dependent:

$$\log s_i \;=\; \rho_0 + \rho_1\,\text{spend}_i + \rho_2\,\text{tickets}_i$$

This means the shape of the Weibull hazard varies across customers. A customer might have both a longer expected lifetime (larger $\eta$) and more predictable survival (smaller $s$). In PyMC:

with pm.Model() as weibull_aft_hetero:
    # Location coefficients
    alpha = pm.Normal('alpha', mu=0, sigma=2, shape=3)
    # Scale coefficients (matching original code's rho priors)
    rho = pm.Normal('rho', mu=0, sigma=2, shape=3)

    eta = alpha[0] + alpha[1] * spend_std + alpha[2] * tickets_std
    s = pm.math.exp(rho[0] + rho[1] * spend_std + rho[2] * tickets_std)

    y_obs = pm.Gumbel('y_obs', mu=eta[~censored], beta=s[~censored],
                       observed=log_observed_time[~censored])
    y_cens = pm.Potential('y_cens',
        gumbel_log_sf(log_observed_time[censored], eta[censored], s[censored]))

    trace_hetero = pm.sample(1000, tune=2000, cores=4, chains=4,
                             random_seed=42, target_accept=0.9)

This is faithful to the aft_model_factory_explicit function in the original code, which uses separate rho_interc, rho_coeff1, rho_coeff2 parameters for the Gumbel scale. The exp link ensures $s_i > 0$ for every customer.

Weibull vs Log-Logistic: Which Tail Shape?

The Weibull model assumes the hazard rate is monotonic: always increasing, always decreasing, or constant. But some churn patterns are non-monotonic. New users might have high churn risk initially (they haven't found value yet), which drops as they engage, then rises again as they outgrow the product.

The Log-Logistic AFT model handles this. In log-time, the Log-Logistic corresponds to a Logistic distribution, just as the Weibull corresponds to a Gumbel. The swap is straightforward:

def logistic_log_sf(y, mu, sigma):
    """Log survival function of the Logistic distribution."""
    return -pt.softplus((y - mu) / sigma)

with pm.Model() as loglogistic_aft:
    alpha = pm.Normal('alpha', mu=0, sigma=2, shape=3)
    log_s = pm.Normal('log_s', mu=0, sigma=1)
    s = pm.Deterministic('s', pm.math.exp(log_s))

    eta = alpha[0] + alpha[1] * spend_std + alpha[2] * tickets_std

    y_obs = pm.Logistic('y_obs', mu=eta[~censored], s=s,
                         observed=log_observed_time[~censored])
    y_cens = pm.Potential('y_cens',
        logistic_log_sf(log_observed_time[censored], eta[censored], s))

    trace_ll = pm.sample(1000, tune=2000, cores=4, chains=4,
                         random_seed=42, target_accept=0.9)

Two-panel comparison of Weibull (left) and Log-Logistic (right) survival curves for the average customer. The Weibull curve decays smoothly following a stretched exponential. The Log-Logistic curve has a heavier tail, decaying more slowly at longer times.
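The monotone-vs-non-monotone distinction is easy to verify numerically. A quick sketch of the two hazard functions (the shape and scale values here are illustrative, not fitted):

```python
import numpy as np

t = np.linspace(0.1, 30, 300)

# Weibull hazard h(t) = (k/lam) * (t/lam)**(k-1): monotone for any shape k
k, lam = 1.5, 12.0
h_weibull = (k / lam) * (t / lam) ** (k - 1)

# Log-Logistic hazard: rises then falls whenever the shape b > 1
b, a = 2.0, 12.0
h_loglogistic = (b / a) * (t / a) ** (b - 1) / (1 + (t / a) ** b)

print(np.all(np.diff(h_weibull) > 0))   # True: hazard only increases
peak = t[np.argmax(h_loglogistic)]
print(peak)                             # interior peak: churn risk rises, then falls
```

For this shape value ($b = 2$) the Log-Logistic hazard peaks at $t = a$ and then declines, exactly the "high early risk, settling later" pattern a Weibull cannot express.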

Compare the two models using LOO-CV (leave-one-out cross-validation) with ArviZ. Two practical notes: each trace needs a stored pointwise log-likelihood (in recent PyMC versions, call pm.compute_log_likelihood inside the model context, or pass idata_kwargs={"log_likelihood": True} to pm.sample), and pm.Potential terms are excluded from the pointwise log-likelihood, so the comparison is driven by the uncensored observations only:

with weibull_aft:
    pm.compute_log_likelihood(trace)
with loglogistic_aft:
    pm.compute_log_likelihood(trace_ll)

print(az.compare({'Weibull': trace, 'Log-Logistic': trace_ll}))

Since our synthetic data was generated from a Weibull distribution, the Weibull model should win. On real data, the comparison often reveals which tail shape better captures your customers' churn dynamics.

The Cox Proportional Hazards Alternative

Survival analysis has a dominant semi-parametric approach: the Cox proportional hazards (PH) model. It doesn't assume a distribution for the baseline hazard, only that covariates multiply the hazard by a constant factor. This flexibility made it ubiquitous in clinical trials.

So why choose a parametric Bayesian AFT model? Three reasons:

  1. Full predictive distributions. The Cox model gives hazard ratios, but producing survival curves requires additional estimation of the baseline hazard. Our Bayesian AFT model gives survival curves with uncertainty bands directly from the posterior.
  2. Small samples and heavy censoring. With many active customers, the Cox model's partial likelihood can be imprecise. Bayesian priors stabilise estimates, especially for rare covariates. This is the same principle of "borrowing strength" we explored in the hierarchical regression post.
  3. Natural extension. PyMC models compose freely. Adding group structure (churn by subscription tier), time-varying covariates, or custom likelihoods is straightforward. The next post in this series demonstrates exactly this with a one-inflated Beta regression.

Flow diagram showing the AFT model structure. Covariates (monthly spend, support tickets) feed into two linear predictors: one for the location parameter eta and one for the scale parameter s. These combine into a Gumbel distribution for log-time, which maps to a Weibull distribution for actual survival time. Censored and uncensored paths split at the likelihood.

When NOT to Use Bayesian AFT

If the proportional hazards assumption holds and your dataset is large (tens of thousands of events), the Cox model is faster and assumption-lighter. If you have time-varying covariates that change during a customer's lifetime (e.g., monthly usage patterns), the standard AFT formulation doesn't handle them naturally; you'd need a piecewise approach or a joint model.

Computational cost matters too. Our 1,000-customer model samples in a few minutes, but production datasets with millions of rows would require approximations like variational inference or mini-batch MCMC.

Where This Comes From

Cox (1972): Proportional Hazards

The modern era of survival analysis began with David Cox's 1972 paper "Regression Models and Life-Tables." Cox introduced the proportional hazards model:

$$h(t \mid \mathbf{x}) \;=\; h_0(t)\,\exp(\boldsymbol{\beta}^\top \mathbf{x})$$

where $h_0(t)$ is an unspecified baseline hazard. The genius was leaving $h_0$ unspecified and estimating $\boldsymbol{\beta}$ through the partial likelihood, which depends only on the order of events, not their exact times. This paper has been cited over 65,000 times and remains the most-used method in clinical trials.

"The important practical point is that [the partial likelihood] does not require specification of $h_0(t)$." (Cox, 1972)

Our AFT model takes a different path: we specify a distribution (Weibull or Log-Logistic), which enables direct time predictions. This parametric assumption is both a strength (more powerful inference when correct) and a weakness (biased inference when wrong).

Buckley and James (1979): Accelerated Failure Time

The AFT framework was formalised by Buckley and James in 1979. Their key insight was that the AFT model has a direct linear regression interpretation:

$$\log T_i \;=\; \alpha_0 + \alpha_1 x_{i1} + \cdots + \alpha_p x_{ip} + \epsilon_i$$

where $\epsilon_i$ follows a known distribution (Gumbel for Weibull, Logistic for Log-Logistic). The coefficients $\alpha_j$ have a clean meaning: a one-unit increase in $x_j$ multiplies the median survival time by $\exp(\alpha_j)$. This is why it's called "accelerated failure time": covariates speed up or slow down the passage of time.

Wei (1992): AFT as an Alternative

L. J. Wei's 1992 paper "The Accelerated Failure Time Model: A Useful Alternative to the Cox Regression Model in Survival Analysis" made the case for AFT models as a practical complement to Cox PH. Wei showed that AFT models are more robust to omitted covariates and provide more interpretable effect sizes.

"When the acceleration factor is constant over time, the AFT model provides a simple and clinically meaningful summary of the survival experience." (Wei, 1992)

Handling Censoring in PyMC

The pm.Potential approach for censored data follows directly from the likelihood factorisation. For a dataset with observed and censored outcomes:

$$\mathcal{L}(\theta) \;=\; \prod_{i:\,\text{churned}} f(t_i \mid \theta) \;\times\; \prod_{i:\,\text{censored}} S(c_i \mid \theta)$$

Taking logs, the uncensored terms give the standard log-likelihood (handled by pm.Gumbel or pm.Logistic). The censored terms give log-survival values (handled by pm.Potential). This pattern appears throughout the PyMC survival analysis examples and extends naturally to interval censoring and left censoring by swapping the survival function for the appropriate probability term.
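The same swap works in plain NumPy terms. A left-censored observation (the event happened before $a$, but we did not see when) contributes $\log F(a)$, and an interval-censored one on $(a, b]$ contributes $\log\{F(b) - F(a)\}$. A sketch with illustrative location and scale values:

```python
import numpy as np

def gumbel_cdf(y, mu, sigma):
    """F(y) for a (max-)Gumbel: exp(-exp(-(y - mu) / sigma))."""
    return np.exp(-np.exp(-(y - mu) / sigma))

mu, s = 2.5, 0.6                   # illustrative location and scale
a, b = np.log(6.0), np.log(12.0)   # e.g. churned somewhere between months 6 and 12

log_left_cens = np.log(gumbel_cdf(a, mu, s))                      # churned before 6
log_interval = np.log(gumbel_cdf(b, mu, s) - gumbel_cdf(a, mu, s))
print(log_left_cens)   # ≈ -3.26
print(log_interval)    # ≈ -1.14
```

Either quantity can be passed to pm.Potential in exactly the same way as the right-censored log-survival term, with the NumPy calls swapped for their pytensor equivalents.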

Further Reading

  • The proportional hazards model: Cox, D. R. (1972). "Regression Models and Life-Tables." Journal of the Royal Statistical Society: Series B, 34(2), 187-220.
  • The AFT framework: Buckley, J. & James, I. (1979). "Linear Regression with Censored Data." Biometrika, 66(3), 429-436.
  • AFT as a Cox alternative: Wei, L. J. (1992). "The Accelerated Failure Time Model." Statistics in Medicine, 11(14-15), 1871-1879.
  • The standard reference: Kalbfleisch, J. D. & Prentice, R. L. (2002). The Statistical Analysis of Failure Time Data, 2nd ed. Wiley.
  • PyMC survival example: Weibull AFT notebook
  • Previous in this series: Hierarchical Bayesian Regression with PyMC, which introduces PyMC, partial pooling, and ArviZ diagnostics.
  • Next in this series: Custom likelihoods in PyMC, where we build a one-inflated Beta regression for bounded outcome data.

Frequently Asked Questions

What is right-censoring and why does it matter?

Right-censoring occurs when you know a subject survived at least until a certain time, but not the actual event time. In churn analysis, active customers are right-censored because they have not yet churned. Ignoring them biases your model toward shorter lifetimes, since you only learn from customers who already left. Survival analysis handles censoring properly by using the survival function for these partial observations.

What is the difference between the Cox model and an AFT model?

The Cox proportional hazards model is semi-parametric: it leaves the baseline hazard unspecified and estimates how covariates multiply the hazard rate. The accelerated failure time (AFT) model is fully parametric: it assumes a specific distribution (such as Weibull) and models how covariates accelerate or decelerate time to event. AFT coefficients have a direct interpretation as multipliers on median survival time, while Cox coefficients are hazard ratios.

What does pm.Potential do in PyMC?

pm.Potential adds an arbitrary log-probability term directly to the model's log-posterior. For censored observations, there is no fully observed outcome to pass to a standard likelihood. Instead, you compute the log-survival probability and add it via pm.Potential, telling PyMC that these customers survived at least this long without specifying when they will actually churn.

How do I choose between Weibull and Log-Logistic distributions?

Use Weibull when you expect the hazard rate to be monotonic, either always increasing, always decreasing, or constant over time. Use Log-Logistic when the hazard may be non-monotonic, such as high initial churn that drops as users engage and then rises again later. You can compare the two formally using LOO-CV (leave-one-out cross-validation) in ArviZ.

How many customers do I need for a Bayesian survival model?

Bayesian models can work with surprisingly small datasets because priors regularise the estimates, but a practical minimum is a few hundred observations with at least 50 to 100 uncensored events. With heavy censoring (over 80% still active), the model has less information about event times, so you may need a larger sample or more informative priors to get precise estimates.

Can I add time-varying covariates to a Bayesian AFT model?

The standard AFT formulation assumes covariates are fixed at baseline and does not naturally handle features that change during a customer's lifetime, such as monthly usage patterns. For time-varying covariates, you would need a piecewise AFT approach that splits each customer's timeline into intervals, or a joint model that links the longitudinal covariate process with the survival outcome.
