
Maureen Muthoni
Customer Lifetime Value (CLV) Prediction with Machine Learning

Introduction

Customer acquisition is expensive. But do you know which customers will actually generate long-term revenue? That's where Customer Lifetime Value (CLV) comes in.

Instead of focusing on one-off transactions, CLV estimates the total revenue a business expects from a customer over their entire relationship.

In this project, I built an end-to-end CLV prediction model and then deployed it as a production-ready API.

In this article, we’ll cover:

  • Business problem
  • Data preprocessing
  • Model development
  • Model evaluation
  • Model deployment with FastAPI
  • Production-ready setup

The Business Problem

Businesses want to answer:

  • Which customers are most valuable?
  • Who should receive retention incentives?
  • Where should marketing budgets be allocated?

Predicting CLV helps with:

  • Customer segmentation
  • Revenue forecasting
  • Budget optimization
  • Retention strategies

This is a regression problem since CLV is a continuous value.

Step 1: Data Preprocessing

The dataset includes:

  • Purchase frequency
  • Recency
  • Average transaction value
  • Tenure
  • Demographic features
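
The post assumes these features already exist in the dataset. If you only have raw transaction logs, they can be derived with a `groupby`. A minimal sketch, assuming hypothetical columns `customer_id`, `order_date`, and `amount`:

```python
import pandas as pd

# Hypothetical raw transaction log
tx = pd.DataFrame({
    'customer_id': [1, 1, 2, 2, 2],
    'order_date': pd.to_datetime(['2024-01-05', '2024-03-10',
                                  '2024-02-01', '2024-02-20', '2024-04-01']),
    'amount': [50.0, 75.0, 20.0, 30.0, 25.0],
})

# Fix a snapshot date so recency/tenure are measured consistently
snapshot = tx['order_date'].max()

features = tx.groupby('customer_id').agg(
    frequency=('order_date', 'count'),
    recency_days=('order_date', lambda d: (snapshot - d.max()).days),
    avg_transaction=('amount', 'mean'),
    tenure_days=('order_date', lambda d: (snapshot - d.min()).days),
).reset_index()

print(features)
```

This is the classic RFM-style feature build; the real pipeline would of course run over the full transaction history.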

Data Preparation

Before training any model, we need to separate our features from the target variable. In this case, CLV is what we're trying to predict, and everything else in the dataset serves as input:

x = df.drop('CLV', axis=1)
y = df['CLV']

We also check for missing values:

x.isnull().sum()

Clean data is non-negotiable. Missing values can silently corrupt a model's performance if left unaddressed.
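
If `isnull().sum()` does reveal gaps, one common, simple strategy is median imputation for numeric columns. A toy sketch with made-up values:

```python
import pandas as pd
import numpy as np

# Toy frame with deliberate gaps
df = pd.DataFrame({
    'Tenure_Months': [12, None, 30],
    'Monthly_Spend': [40.0, 55.0, np.nan],
})

print(df.isnull().sum())

# Impute numeric gaps with each column's median
df_filled = df.fillna(df.median(numeric_only=True))
print(df_filled)
```

Median imputation is robust to outliers, but for real CLV work it is worth asking *why* values are missing before choosing a strategy.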

Splitting the Dataset

We divide the data into training and testing sets: 80% for training and 20% for evaluating performance on unseen data:

from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

Setting random_state=42 ensures reproducibility, so results remain consistent across runs.

Step 2: Model Development

Linear Regression
We start with linear regression, a simple but interpretable baseline. It assumes a linear relationship between the features and the target, making it fast to train and easy to explain to stakeholders.

from sklearn.linear_model import LinearRegression

Linear = LinearRegression()
Linear.fit(x_train, y_train)
Predictions = Linear.predict(x_test)

Random Forest Regressor
Next, we train a Random Forest, an ensemble method that builds 200 decision trees and averages their predictions. This approach is more robust to non-linear patterns in the data and tends to outperform linear models on complex real-world datasets.

from sklearn.ensemble import RandomForestRegressor

rf = RandomForestRegressor(n_estimators=200, random_state=42)
rf.fit(x_train, y_train)
random_prediction = rf.predict(x_test)

Step 3: Model Evaluation

We evaluate both models using Root Mean Squared Error (RMSE) and R² Score. RMSE tells us the average prediction error in the same units as CLV, while R² tells us how much of the variance in CLV our model explains (1.0 = perfect, 0 = no better than guessing the mean).

from sklearn.metrics import mean_squared_error, r2_score
from math import sqrt

RMSE_linear = sqrt(mean_squared_error(y_test, Predictions))
r2_linear = r2_score(y_test, Predictions)

RMSE_tree = sqrt(mean_squared_error(y_test, random_prediction))
r2_tree = r2_score(y_test, random_prediction)

print(f'RMSE_linear: {RMSE_linear}')
print(f'r2_linear:   {r2_linear}')
print(f'RMSE_tree:   {RMSE_tree}')
print(f'r2_tree:     {r2_tree}')

In most real-world CLV scenarios, the Random Forest will outperform Linear Regression due to its ability to capture complex, non-linear relationships between customer features and lifetime value.
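
Once predictions look reliable, they can feed directly into the segmentation use case from the business problem. A small sketch, using hypothetical predicted values, that splits customers into three value tiers with `pd.qcut`:

```python
import pandas as pd

# Hypothetical predicted CLV values for a batch of customers
pred = pd.Series([120.0, 800.0, 450.0, 90.0, 1500.0, 300.0])

# Equal-sized quantile bins: bottom, middle, and top third by predicted value
tiers = pd.qcut(pred, q=3, labels=['Low', 'Mid', 'High'])

print(pd.DataFrame({'predicted_CLV': pred, 'tier': tiers}))
```

The 'High' tier is the natural target for retention incentives, while 'Low' customers may not justify expensive campaigns.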

Saving the Model
Once we're satisfied with model performance, we persist the trained model and feature schema using joblib. This makes reloading the model later without retraining straightforward:

import joblib

joblib.dump(rf, 'CLV_model.joblib')
joblib.dump(list(x.columns), 'modelfeatures.joblib')

Saving the feature set alongside the model is a great practice. It documents exactly what columns and structure the model expects at inference time, which prevents subtle bugs when deploying.
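
A quick round-trip sanity check on a toy model confirms that a dumped and reloaded model reproduces its predictions exactly (the file name `demo_model.joblib` here is arbitrary):

```python
import joblib
import numpy as np
from sklearn.ensemble import RandomForestRegressor

# Tiny synthetic problem: predict the row sum of two features
x = np.arange(20).reshape(10, 2)
y = x.sum(axis=1).astype(float)

rf = RandomForestRegressor(n_estimators=10, random_state=42).fit(x, y)

# Persist, reload, and compare predictions
joblib.dump(rf, 'demo_model.joblib')
reloaded = joblib.load('demo_model.joblib')

assert np.allclose(rf.predict(x), reloaded.predict(x))
```

If this assertion ever fails in a real project, the usual culprit is a scikit-learn version mismatch between the environment that saved the model and the one loading it.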

Step 4: Model Deployment with FastAPI

Training a model is only half the work. To put it into production, you need an API that other systems can call. Here's how to build a simple REST endpoint using FastAPI:

1. Install Dependencies

pip install fastapi uvicorn joblib scikit-learn pandas

2. Create the API

from fastapi import FastAPI 
from pydantic import BaseModel
import joblib
import numpy as np 

app = FastAPI(title='Customer Lifetime Value Prediction API')

# Load the saved model and feature schema
model = joblib.load('CLV_model.joblib')
feature_name = joblib.load('modelfeatures.joblib')

# Define the input schema (adjust fields to match your actual dataset columns)
class CLVinput(BaseModel):
    Customer_Age: int
    Annual_Income: float
    Tenure_Months: int
    Monthly_Spend: float
    Visits_Per_Month: int
    Avg_Basket_Value: float
    Support_Tickets: int

@app.get("/")
def health_check():
    return {"status": "API is running"}

@app.post('/predict-CLV')
def predict_CLV(data: CLVinput):
    x = np.array([[getattr(data, f) for f in feature_name]])
    prediction = model.predict(x)[0]
    # Cast the numpy float to a plain Python float so it serializes to JSON
    return {'predicted_CLV': float(prediction)}

3. Run the Server Locally

uvicorn app:app --reload

Your API will be live at http://localhost:8000. You can test it at http://localhost:8000/docs. FastAPI generates interactive API documentation automatically.

4. Deploy to the Cloud

For production, deploy the API to a cloud provider. Railway and Render are the simplest options: push your code to GitHub and connect the repo. Both platforms auto-detect Python apps and handle deployment with minimal configuration. Add a requirements.txt file:

fastapi
uvicorn
joblib
scikit-learn
pandas
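
Both platforms also expect a start command. Assuming the API file is named app.py, something like:

```shell
uvicorn app:app --host 0.0.0.0 --port $PORT
```

Here $PORT is the environment variable the platform injects at runtime; binding to 0.0.0.0 makes the server reachable from outside the container.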

Summary

Here's the end-to-end workflow we covered:

  1. Load and explore the customer dataset
  2. Prepare features by separating inputs from the CLV target
  3. Train two models (Linear Regression and Random Forest) and compare them using RMSE and R²
  4. Save the best model using joblib
  5. Deploy via FastAPI with a /predict endpoint that accepts customer data and returns a CLV estimate

Predicting Customer Lifetime Value turns raw customer data into a strategic business asset. With a deployed model, your sales and marketing teams can make real-time decisions based on predicted value, not just historical behaviour.
