
Adeline Makokha [AWS Hero]


I Built an ML Churn Predictor in Minutes: Here's How Kiro Made It Possible

Customer churn is one of the most expensive problems in the telecom industry. Acquiring a new customer is commonly estimated to cost 5–10× more than retaining an existing one, yet most companies only discover that a customer has churned after they have already left. The goal of this project is to flip that: give analysts a tool to identify at-risk customers before they churn, so retention teams can act proactively.

What would normally take days of planning, scaffolding, and wiring together took a fraction of the time, because I built it with Kiro, an AI-powered development environment that thinks in specs, not just code completions.

In this article I'll walk through building a complete churn prediction web application from scratch using Python, Flask, scikit-learn, and Plotly. By the end you'll have a working app that:

  • Accepts CSV uploads of customer data
  • Runs a Random Forest churn prediction model
  • Visualises results with three interactive charts
  • Lets you browse, filter, and sort at-risk customers
  • Exports results to CSV for downstream use

How Kiro Accelerated This Build

Before diving into the code, it's worth explaining why this came together so fast.

Most AI coding tools are reactive: you write code, they autocomplete. Kiro works differently. It starts with a spec-driven workflow where you describe what you want to build, and Kiro helps you think through requirements, design, and implementation tasks before a single line of code is written.

Here's exactly how this project unfolded:

1. Requirements in minutes, not hours

I described the project in plain English ("a telecom customer churn prediction website using Python"), and Kiro generated a full requirements document covering 7 requirement areas, with precise, testable acceptance criteria in EARS (Easy Approach to Requirements Syntax) format. Things like:

"IF a Dataset contains up to 10,000 Customer Records, THEN THE Predictor SHALL complete prediction within 30 seconds."

No ambiguity. No back-and-forth. Edge cases I hadn't even thought about, like what happens when tenure = 0, or when a CSV is valid but contains zero data rows, were already covered.

2. Technical design with 15 correctness properties

From the requirements, Kiro produced a full technical design document: component interfaces with Python signatures, data models, an architecture diagram, a Flask route table, and 15 formal correctness properties to be verified with property-based tests. For example:

"For any array of churn scores and any threshold in [0.0, 1.0], compute_churn_rate SHALL return exactly round((count of scores >= threshold) / (total count) * 100, 2)."

This is the kind of rigour that usually only happens on large teams with dedicated QA. Kiro baked it in from the start.

3. Implementation tasks, automatically sequenced

Kiro then broke the design into a dependency-ordered task list, 13 top-level tasks across 8 parallel waves, from project scaffolding through to integration tests. Each task referenced specific requirements for traceability.

4. Code generation that actually matches the spec

With the spec in place, Kiro generated all the Python modules, Flask routes, Jinja2 templates, JavaScript, and sample data, and the code matched the design document precisely. No hallucinated APIs, no mismatched interfaces.

The result is a production-quality app with 146 passing tests (unit, integration, and property-based) generated from a single plain English description.

The takeaway: Kiro doesn't just write code faster. It helps you build the right thing by front-loading the thinking. The spec becomes the source of truth, and the code follows from it.



What We're Building

Here's the full feature set at a glance:

Feature            Details
CSV Upload         Validates format, size (≤50 MB), required columns, and row-level data quality
Churn Prediction   Random Forest model, configurable threshold (default 0.5)
Dashboard          Summary stats + 3 Plotly charts
At-Risk Table      Paginated (25/page), sortable, filterable
Export             Download results as CSV
Model Info         Displays model name, version, and training date

Architecture Overview

The app follows a clean separation of concerns. Each responsibility lives in its own module.

Request flows:

  1. Upload: Validator checks the file → valid rows stored in AppState
  2. Predict: Predictor scores every row → Visualizer builds chart specs → stored in AppState
  3. Export: Exporter serialises results → streamed as file download
  4. Startup: ModelLoader loads model.joblib once; if it fails, prediction is disabled

Project Structure

telecom-churn-app/
├── app.py               # Flask routes and AppState
├── validator.py         # CSV upload validation
├── predictor.py         # Churn scoring
├── visualizer.py        # Plotly chart builders
├── exporter.py          # CSV export
├── model_loader.py      # joblib model loading
├── table_helpers.py     # Pagination, sort, filter
├── generate_model.py    # One-time model training script
├── requirements.txt
├── data/
│   ├── sample_customers.csv   # 200 rows
│   └── sample_small.csv       # 20 rows for quick testing
├── templates/
│   ├── base.html
│   └── dashboard.html
└── static/
    └── app.js

Getting Started

Prerequisites

  • Python 3.11+
  • pip

Install dependencies

pip install -r requirements.txt

requirements.txt:

Flask==3.0.3
pandas==2.2.2
numpy==1.26.4
scikit-learn==1.5.0
joblib==1.4.2
plotly==5.22.0
pytest==8.2.2
hypothesis==6.103.1

Generate the model

python generate_model.py
# → Model saved to model.joblib

Run the app

python app.py
# → Running on http://localhost:5000

The Data

The app expects a CSV with these five columns:

Column            Type      Rules
customer_id       string    Required
tenure            numeric   0 – 999 (months)
monthly_charges   numeric   > 0
total_charges     numeric   > 0
contract_type     string    "Month-to-month", "One year", "Two year"

Sample rows from data/sample_small.csv:

customer_id,tenure,monthly_charges,total_charges,contract_type
CUST0001,69,113.04,7560.84,One year
CUST0004,41,59.11,2325.40,Month-to-month
CUST0007,68,113.62,7797.64,Two year
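
To make the column contract concrete, here is a minimal standard-library sketch of a header check against the sample rows above. It is not the project's validator.py: the helper names missing_columns and unknown_contract_types are hypothetical, and rejecting unknown contract types at this stage is an assumption based on the rules table.

```python
import csv
import io

# Column names and allowed contract values are taken from the table above.
REQUIRED_COLUMNS = ["customer_id", "tenure", "monthly_charges",
                    "total_charges", "contract_type"]
ALLOWED_CONTRACTS = {"Month-to-month", "One year", "Two year"}

SAMPLE_CSV = """customer_id,tenure,monthly_charges,total_charges,contract_type
CUST0001,69,113.04,7560.84,One year
CUST0004,41,59.11,2325.40,Month-to-month
CUST0007,68,113.62,7797.64,Two year
"""


def missing_columns(csv_text: str) -> list[str]:
    """Hypothetical helper: required columns absent from the CSV header."""
    reader = csv.DictReader(io.StringIO(csv_text))
    header = reader.fieldnames or []
    return [col for col in REQUIRED_COLUMNS if col not in header]


def unknown_contract_types(csv_text: str) -> set[str]:
    """Hypothetical helper: contract_type values outside the allowed set."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return {row["contract_type"] for row in reader} - ALLOWED_CONTRACTS


print(missing_columns(SAMPLE_CSV))                    # → []
print(missing_columns("customer_id,tenure\nC1,3\n"))  # → ['monthly_charges', 'total_charges', 'contract_type']
print(unknown_contract_types(SAMPLE_CSV))             # → set()
```

Checking the header before touching any row lets an upload fail fast with a message that names the missing columns.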

Step 1: Training the Model (generate_model.py)

We generate 1,000 rows of synthetic training data where churn probability is driven by three realistic signals: short tenure, high monthly charges, and month-to-month contracts.


import numpy as np
import pandas as pd


def generate_training_data(n: int = 1000, seed: int = 42) -> pd.DataFrame:
    rng = np.random.default_rng(seed)

    tenure = rng.integers(0, 73, size=n)
    monthly_charges = rng.uniform(20.0, 120.0, size=n).round(2)
    total_charges = (tenure * monthly_charges * rng.uniform(0.95, 1.05, size=n)).round(2)

    contract_type = rng.choice(
        ["Month-to-month", "One year", "Two year"],
        size=n,
        p=[0.5, 0.3, 0.2]
    )

    # Churn probability: higher for short tenure, month-to-month, high charges
    churn_prob = (
        0.4 * (1 - tenure / 72)
        + 0.3 * (monthly_charges - 20) / 100
        + 0.3 * (contract_type == "Month-to-month").astype(float)
    )
    churn_prob = np.clip(churn_prob, 0.05, 0.95)
    churn = rng.binomial(1, churn_prob, size=n)
    ...
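
To see how the three signals combine, the same formula can be evaluated for a single customer. This is a scalar, pure-Python restatement of the vectorised expression above; the function name churn_probability is hypothetical and exists only for illustration.

```python
def churn_probability(tenure: int, monthly_charges: float, contract_type: str) -> float:
    """Scalar version of the synthetic churn formula, clipped to [0.05, 0.95]."""
    raw = (
        0.4 * (1 - tenure / 72)                  # short tenure raises risk
        + 0.3 * (monthly_charges - 20) / 100     # high charges raise risk
        + 0.3 * (1.0 if contract_type == "Month-to-month" else 0.0)
    )
    return min(max(raw, 0.05), 0.95)


# New, expensive, month-to-month customer: 0.3667 + 0.24 + 0.3
print(round(churn_probability(6, 100.0, "Month-to-month"), 4))  # → 0.9067

# Long-tenured, cheap, two-year customer: the raw value is 0.0, clipped up to the floor
print(churn_probability(72, 20.0, "Two year"))                  # → 0.05
```

The clip keeps every synthetic customer away from certain churn or certain retention, so the binomial draw still produces some noise at both extremes.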

After training a RandomForestClassifier, we attach metadata directly to the model object before saving with joblib:

model.metadata = {
    "name": "RandomForestChurnModel",
    "version": "1.0.0",
    "training_date": "2024-01-15",
}
joblib.dump(model, "model.joblib")

This keeps the model and its metadata in a single file, with no separate config needed.


Step 2: Loading the Model (model_loader.py)

The model is loaded once at startup. If the file is missing or corrupt, the app enters a degraded state where prediction is disabled but everything else still works.


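from dataclasses import dataclass
from datetime import date

import joblib


@dataclass
class ModelMetadata:
    name: str
    version: str
    training_date: date  # displayed as ISO 8601 YYYY-MM-DD

@dataclass
class LoadedModel:
    model: object
    metadata: ModelMetadata

class ModelLoadError(Exception):
    pass

def load_model(path: str) -> LoadedModel:
    try:
        model = joblib.load(path)
    except FileNotFoundError as e:
        raise ModelLoadError(f"Model file not found: {path}") from e
    except Exception as e:
        raise ModelLoadError(f"Failed to load model: {e}") from e

    # Missing or malformed metadata must also surface as ModelLoadError;
    # otherwise startup would crash instead of entering the degraded state.
    try:
        raw_meta = model.metadata
        metadata = ModelMetadata(
            name=raw_meta["name"],
            version=raw_meta["version"],
            training_date=date.fromisoformat(raw_meta["training_date"]),
        )
    except (AttributeError, KeyError, TypeError, ValueError) as e:
        raise ModelLoadError(f"Invalid model metadata: {e}") from e

    return LoadedModel(model=model, metadata=metadata)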

In app.py, this runs before the first request is served:

app_state = AppState()

def _load_model_on_startup():
    try:
        app_state.loaded_model = load_model(MODEL_PATH)
    except ModelLoadError as e:
        app_state.model_load_error = str(e)

_load_model_on_startup()
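
The degraded state can then be checked wherever prediction is offered. The sketch below is one way a route or template helper might consult the two fields set at startup. AppStateSketch and prediction_status are hypothetical names, not part of the project's app.py, and the message wording is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AppStateSketch:
    """Stand-in holding only the two fields set during startup."""
    loaded_model: Optional[object] = None
    model_load_error: Optional[str] = None


def prediction_status(state: AppStateSketch) -> tuple[bool, str]:
    """Hypothetical helper: is prediction available, and what should the UI say?"""
    if state.loaded_model is not None:
        return True, "Model loaded; prediction is available."
    reason = state.model_load_error or "no model has been loaded"
    return False, f"Prediction is disabled: {reason}"


degraded = AppStateSketch(model_load_error="Model file not found: model.joblib")
print(prediction_status(degraded))
# → (False, 'Prediction is disabled: Model file not found: model.joblib')
```

Upload and browsing keep working in this state because only the predict action depends on loaded_model.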

Step 3: Validating Uploads (validator.py)

The validator runs a multi-stage pipeline. Each stage can fail fast with a clear error message.

The row-level rules are:

  • monthly_charges and total_charges must be numeric and > 0
  • tenure must be numeric, ≥ 0, and ≤ 999 (tenure = 0 is valid)


def _validate_rows(df: pd.DataFrame) -> tuple[pd.DataFrame, int]:
    work_df = df.copy()
    work_df["_tenure_num"]  = pd.to_numeric(work_df["tenure"],          errors="coerce")
    work_df["_monthly_num"] = pd.to_numeric(work_df["monthly_charges"], errors="coerce")
    work_df["_total_num"]   = pd.to_numeric(work_df["total_charges"],   errors="coerce")

    valid_mask = (
        work_df["_tenure_num"].notna()
        & (work_df["_tenure_num"] >= 0)
        & (work_df["_tenure_num"] <= 999)
        & work_df["_monthly_num"].notna()
        & (work_df["_monthly_num"] > 0)
        & work_df["_total_num"].notna()
        & (work_df["_total_num"] > 0)
    )

    valid_df = df[valid_mask].copy()
    invalid_count = int((~valid_mask).sum())
    return valid_df, invalid_count

If some rows are invalid but at least one is valid, the app warns the user and proceeds with the clean rows. If all rows are invalid, prediction is blocked.
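
That three-way outcome (clean, partial, blocked) can be sketched as a small decision function. The names OutcomeSketch and summarise_rows and the message wording are assumptions for illustration; the project's own messages may differ.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class OutcomeSketch:
    """Trimmed-down stand-in for the validation result."""
    success: bool
    error_message: Optional[str] = None
    warning_message: Optional[str] = None


def summarise_rows(total_rows: int, invalid_rows: int) -> OutcomeSketch:
    """Hypothetical helper mapping row counts to block / warn / proceed."""
    valid_rows = total_rows - invalid_rows
    if valid_rows == 0:
        # Nothing usable: block prediction entirely.
        return OutcomeSketch(False, error_message="No valid customer records found.")
    if invalid_rows > 0:
        # Partially usable: proceed with the clean rows, but tell the user.
        return OutcomeSketch(
            True,
            warning_message=f"{invalid_rows} of {total_rows} row(s) were skipped as invalid.",
        )
    return OutcomeSketch(True)


print(summarise_rows(200, 3).warning_message)  # → 3 of 200 row(s) were skipped as invalid.
print(summarise_rows(5, 5).success)            # → False
```

Note that a CSV with a valid header and zero data rows also lands in the blocked branch, since no valid rows remain.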

The ValidationResult dataclass carries everything the route handler needs:

@dataclass
class ValidationResult:
    success: bool
    error_message: str | None = None
    warning_message: str | None = None
    dataframe: pd.DataFrame | None = None
    total_rows: int = 0
    valid_rows: int = 0
    invalid_rows: int = 0

Step 4: Running Predictions (predictor.py)

The predictor one-hot encodes contract_type to match the training feature set, then calls predict_proba to get churn probabilities:


def predict(df: pd.DataFrame, model, threshold: float = 0.5) -> PredictionResult:
    try:
        feature_df = df[["tenure", "monthly_charges", "total_charges", "contract_type"]].copy()
        feature_df = pd.get_dummies(feature_df, columns=["contract_type"])

        # Ensure all contract type columns exist even if not in this batch
        for col in ["contract_type_Month-to-month", "contract_type_One year", "contract_type_Two year"]:
            if col not in feature_df.columns:
                feature_df[col] = 0

        feature_cols = ["tenure", "monthly_charges", "total_charges",
                        "contract_type_Month-to-month", "contract_type_One year", "contract_type_Two year"]
        X = feature_df[feature_cols].astype(float)

        probas = model.predict_proba(X)
        scores = probas[:, 1]  # probability of churn

        return PredictionResult(success=True, scores=scores,
                                customer_ids=df["customer_id"].tolist(),
                                threshold=threshold)
    except Exception as e:
        return PredictionResult(success=False, error_message=str(e), threshold=threshold)
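
The effect of the fixed column list is easiest to see on a single record. Below is a pure-Python sketch of the same encoding, without pandas; build_feature_row is a hypothetical name. It shows that a customer always maps to six features in the same order, whichever contract types happen to be present in the upload.

```python
# Same order as the contract_type_* columns in feature_cols above.
CONTRACT_TYPES = ["Month-to-month", "One year", "Two year"]


def build_feature_row(record: dict) -> list[float]:
    """Hypothetical helper: three numeric features followed by a three-way one-hot."""
    numeric = [
        float(record["tenure"]),
        float(record["monthly_charges"]),
        float(record["total_charges"]),
    ]
    one_hot = [1.0 if record["contract_type"] == c else 0.0 for c in CONTRACT_TYPES]
    return numeric + one_hot


row = {"tenure": "41", "monthly_charges": "59.11",
       "total_charges": "2325.40", "contract_type": "Month-to-month"}
print(build_feature_row(row))  # → [41.0, 59.11, 2325.4, 1.0, 0.0, 0.0]
```

Without the fill step, a batch containing only month-to-month customers would produce a single dummy column, and the feature matrix would no longer match what the model was trained on.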

The churn rate formula is explicit and deterministic:

def compute_churn_rate(scores: np.ndarray, threshold: float) -> float:
    at_risk_count = int(np.sum(scores >= threshold))
    return round((at_risk_count / len(scores)) * 100, 2)
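
A small worked example shows the boundary behaviour. These are pure-Python stand-ins that operate on lists rather than NumPy arrays. classify_at_risk is named in the project's tests but its body is not shown in this article, so the one below is an assumed implementation.

```python
def classify_at_risk(scores: list[float], threshold: float) -> list[bool]:
    """Assumed implementation: a customer is at risk when score >= threshold."""
    return [score >= threshold for score in scores]


def compute_churn_rate(scores: list[float], threshold: float) -> float:
    """List-based restatement of the formula above."""
    at_risk_count = sum(classify_at_risk(scores, threshold))
    return round(at_risk_count / len(scores) * 100, 2)


scores = [0.2, 0.5, 0.8, 0.9]
print(classify_at_risk(scores, 0.5))      # → [False, True, True, True]
print(compute_churn_rate(scores, 0.5))    # → 75.0  (a score equal to the threshold counts)
print(compute_churn_rate(scores, 0.51))   # → 50.0
```

Both this sketch and the NumPy version divide by len(scores), so an empty input would raise ZeroDivisionError. The upload step avoids that case by blocking prediction when no valid rows remain.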

Step 5: Visualising Results (visualizer.py)

Three Plotly charts are built server-side and serialised to JSON, then rendered client-side with Plotly.newPlot. This keeps the server stateless with respect to chart rendering.

Chart 1: At-Risk vs Non-At-Risk bar chart

At-Risk  ████████████████  87
Non-Risk ████████████████████████████████  113

Chart 2: Churn Score Distribution (histogram)

Exactly 10 bins of width 0.1 spanning [0.0, 1.0]:


import numpy as np
import plotly.graph_objects as go


def build_score_histogram(scores: np.ndarray) -> dict:
    bin_edges = np.linspace(0.0, 1.0, 11)  # 11 edges = 10 bins
    counts, _ = np.histogram(scores, bins=bin_edges)
    # Bar positions: the midpoint of each bin (0.05, 0.15, ..., 0.95)
    bin_centers = [float(bin_edges[i] + bin_edges[i + 1]) / 2 for i in range(10)]

    fig = go.Figure(data=[go.Bar(x=bin_centers, y=counts.tolist(), width=0.09)])
    fig.update_layout(title="Churn Score Distribution")  # ... further layout options elided
    return fig.to_dict()
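
One detail worth seeing in isolation is how a score of exactly 1.0 is binned: np.histogram treats the last bin as closed on the right, so 1.0 lands in the tenth bin rather than falling off the end. The pure-Python sketch below imitates that behaviour. histogram_counts is a hypothetical helper, and scores sitting exactly on an interior edge may be binned differently from NumPy because of floating-point rounding.

```python
def histogram_counts(scores: list[float], n_bins: int = 10) -> list[int]:
    """Count scores in [0.0, 1.0] into equal-width bins; the last bin includes 1.0."""
    counts = [0] * n_bins
    for score in scores:
        # int(score * n_bins) is the bin index; clamp so 1.0 goes to the last bin.
        index = min(int(score * n_bins), n_bins - 1)
        counts[index] += 1
    return counts


print(histogram_counts([0.0, 0.05, 0.15, 0.95, 1.0]))
# → [2, 1, 0, 0, 0, 0, 0, 0, 0, 2]
```

Fixing the bin edges up front, instead of letting the plotting library choose them, keeps the chart comparable from one upload to the next.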

Chart 3: Churn Rate by Contract Type

Month-to-month  ████████████████████████  62.4%
One year        ████████  21.3%
Two year        ████  10.1%
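
The article does not show how the per-contract figures are computed, so the following is a sketch under one assumption: "churn rate by contract type" means the percentage of customers in each group whose score meets the threshold. The function name, the one-decimal rounding, and the sample records are all illustrative.

```python
from collections import defaultdict


def churn_rate_by_contract(records: list[dict], threshold: float = 0.5) -> dict[str, float]:
    """Assumed definition: share of at-risk customers within each contract type."""
    totals: dict[str, int] = defaultdict(int)
    at_risk: dict[str, int] = defaultdict(int)
    for record in records:
        contract = record["contract_type"]
        totals[contract] += 1
        if record["churn_score"] >= threshold:
            at_risk[contract] += 1
    return {c: round(at_risk[c] / totals[c] * 100, 1) for c in totals}


records = [
    {"contract_type": "Month-to-month", "churn_score": 0.82},
    {"contract_type": "Month-to-month", "churn_score": 0.31},
    {"contract_type": "One year", "churn_score": 0.12},
    {"contract_type": "Two year", "churn_score": 0.64},
]
print(churn_rate_by_contract(records))
# → {'Month-to-month': 50.0, 'One year': 0.0, 'Two year': 100.0}
```

Because the rate depends on the threshold, this chart has to be rebuilt whenever the threshold changes.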

Step 6: The Flask Application (app.py)

All mutable state lives in a single AppState dataclass — a simple singleton for single-user deployments:


@dataclass
class AppState:
    dataset: pd.DataFrame | None = None
    prediction_result: PredictionResult | None = None
    threshold: float = 0.5
    loaded_model: LoadedModel | None = None
    model_load_error: str | None = None
    chart_specs: dict | None = None

The six routes map cleanly to user actions:

Route        Method   Action
/            GET      Redirect to dashboard
/dashboard   GET      Render main page
/upload      POST     Validate and store CSV
/predict     POST     Run prediction
/threshold   POST     Update churn threshold
/export      GET      Stream CSV download

The upload route shows the validation pipeline in action:

@app.route("/upload", methods=["POST"])
def upload():
    uploaded_file = request.files["file"]
    file_bytes = uploaded_file.read()
    result = validate_upload(file_bytes, uploaded_file.filename, len(file_bytes))

    if not result.success:
        flash(result.error_message, "error")
        return redirect(url_for("dashboard"))

    # Clear previous results when new data is uploaded
    app_state.dataset = result.dataframe
    app_state.prediction_result = None
    app_state.chart_specs = None

    if result.warning_message:
        flash(result.warning_message, "warning")

    flash(f"File uploaded successfully. {result.valid_rows} customer record(s) loaded.", "success")
    return redirect(url_for("dashboard"))

The threshold route validates the range before accepting the new value:

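@app.route("/threshold", methods=["POST"])
def update_threshold():
    raw_value = request.form.get("threshold", "")
    try:
        threshold_val = float(raw_value)
    except ValueError:
        # A blank or non-numeric field would otherwise raise and return a 500
        flash(f"Threshold {raw_value!r} is not a number. Valid range is [0.0, 1.0].", "error")
        return redirect(url_for("dashboard"))

    if not validate_threshold(threshold_val):
        flash(f"Threshold {threshold_val} is out of range. Valid range is [0.0, 1.0].", "error")
        return redirect(url_for("dashboard"))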

    app_state.threshold = threshold_val
    # Rebuild charts immediately if results exist
    if app_state.prediction_result and app_state.prediction_result.success:
        app_state.prediction_result.threshold = threshold_val
        app_state.chart_specs = _build_chart_specs(app_state.dataset, app_state.prediction_result)

    flash(f"Threshold updated to {threshold_val}.", "success")
    return redirect(url_for("dashboard"))
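
The route calls validate_threshold, but its body is not shown here. A minimal sketch of what such a check might look like follows, together with a hypothetical parse_threshold wrapper for form input. Rejecting NaN and infinity is an assumption, not something the article states.

```python
import math
from typing import Optional


def validate_threshold(value: float) -> bool:
    """Assumed implementation: finite and within the closed range [0.0, 1.0]."""
    return math.isfinite(value) and 0.0 <= value <= 1.0


def parse_threshold(raw: str) -> Optional[float]:
    """Hypothetical helper: form string to a valid threshold, or None."""
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return None
    return value if validate_threshold(value) else None


print(parse_threshold("0.35"))  # → 0.35
print(parse_threshold(""))      # → None  (blank form field)
print(parse_threshold("nan"))   # → None  (float("nan") parses, so it needs its own check)
print(parse_threshold("1.5"))   # → None  (out of range)
```

The NaN case matters because every comparison against NaN is false, so a NaN threshold would silently classify no customer as at risk.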

Step 7: The Dashboard UI

The dashboard uses a two-column Bootstrap 5 layout: a narrow left sidebar for controls, and a wide right panel for results.

Chart data is injected into the page as JSON and rendered by Plotly client-side:

<!-- In dashboard.html -->
{% if has_results %}
<script id="chart-data" type="application/json">{{ chart_data_json | safe }}</script>
<script id="table-data" type="application/json">{{ at_risk_table_json | safe }}</script>
{% endif %}

// In app.js
function renderCharts() {
  const chartData = JSON.parse(document.getElementById('chart-data').textContent);
  const config = { responsive: true, displayModeBar: false };

  Plotly.newPlot('chart-at-risk',   chartData.at_risk_bar.data,    chartData.at_risk_bar.layout,    config);
  Plotly.newPlot('chart-histogram', chartData.score_histogram.data, chartData.score_histogram.layout, config);
  Plotly.newPlot('chart-contract',  chartData.contract_type.data,   chartData.contract_type.layout,   config);
}

Step 8: At-Risk Table: Pagination, Sort, Filter

The table helpers are pure Python functions, independently testable and reusable:


# table_helpers.py

def paginate(records: list, page: int, page_size: int = 25) -> list:
    start = (page - 1) * page_size
    return records[start : start + page_size]

def sort_records(records: list, column: str, direction: str) -> list:
    reverse = direction.lower() == "desc"
    return sorted(records, key=lambda r: (r.get(column) is None, r.get(column)), reverse=reverse)

def filter_records(records: list, search_term: str) -> list:
    if not search_term:
        return records
    term = search_term.lower()
    return [r for r in records if term in str(r.get("customer_id", "")).lower()]
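
Because the helpers are pure functions, they compose in a fixed order: filter, then sort, then paginate. The usage sketch below repeats the three functions so that it runs on its own. The sample records and the page-count calculation with math.ceil are illustrative additions.

```python
import math


def paginate(records: list, page: int, page_size: int = 25) -> list:
    start = (page - 1) * page_size
    return records[start : start + page_size]


def sort_records(records: list, column: str, direction: str) -> list:
    reverse = direction.lower() == "desc"
    return sorted(records, key=lambda r: (r.get(column) is None, r.get(column)), reverse=reverse)


def filter_records(records: list, search_term: str) -> list:
    if not search_term:
        return records
    term = search_term.lower()
    return [r for r in records if term in str(r.get("customer_id", "")).lower()]


# 60 illustrative records: CUST0001 (score 0.99) down to CUST0060 (score 0.40)
records = [{"customer_id": f"CUST{i:04d}", "churn_score": round(1 - i / 100, 2)}
           for i in range(1, 61)]

ordered = sort_records(filter_records(records, ""), "churn_score", "desc")
total_pages = math.ceil(len(ordered) / 25)
page_three = paginate(ordered, page=3)

print(total_pages)                   # → 3
print(len(page_three))               # → 10
print(ordered[0]["customer_id"])     # → CUST0001
```

Filtering before pagination matters: the page count has to be derived from the filtered list, not from the full dataset.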

The client-side JavaScript mirrors this logic for instant interactivity without round-trips:

// Sort on column header click
document.querySelectorAll('#atRiskTable thead th[data-col]').forEach(th => {
  th.addEventListener('click', function () {
    const col = this.getAttribute('data-col');
    sortDirection = (sortColumn === col && sortDirection === 'asc') ? 'desc' : 'asc';
    sortColumn = col;
    currentPage = 1;
    renderTable();
  });
});

// Filter on search input
document.getElementById('tableSearch').addEventListener('input', function () {
  const term = this.value.toLowerCase();
  filteredRecords = allRecords.filter(r =>
    String(r.customer_id || '').toLowerCase().includes(term)
  );
  currentPage = 1;
  renderTable();
});

Step 9: Exporting Results (exporter.py)

The export produces a clean CSV with a fixed column order. One detail worth noting: pandas serialises Python booleans as True/False (capitalised) by default, but the spec requires lowercase true/false. We handle this explicitly:

EXPORT_COLUMNS = ["customer_id", "churn_score", "is_at_risk",
                  "contract_type", "tenure", "monthly_charges"]

def build_export_dataframe(df, scores, at_risk_flags) -> pd.DataFrame:
    return pd.DataFrame({
        "customer_id":     df["customer_id"].values,
        "churn_score":     np.round(scores, 4),       # 4 decimal places
        "is_at_risk":      at_risk_flags.astype(bool),
        "contract_type":   df["contract_type"].values,
        "tenure":          df["tenure"].values,
        "monthly_charges": df["monthly_charges"].values,
    })[EXPORT_COLUMNS]

def to_csv_bytes(export_df: pd.DataFrame) -> bytes:
    out_df = export_df.copy()
    out_df["is_at_risk"] = out_df["is_at_risk"].map({True: "true", False: "false"})
    return out_df.to_csv(index=False).encode("utf-8")
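
The same output rules (fixed column order, four-decimal scores, lowercase booleans) can be illustrated without pandas. This is a standard-library stand-in, not the project's exporter. export_rows_to_csv is a hypothetical name, and the explicit "\n" line terminator is chosen only to keep the printed output simple.

```python
import csv
import io

EXPORT_COLUMNS = ["customer_id", "churn_score", "is_at_risk",
                  "contract_type", "tenure", "monthly_charges"]


def export_rows_to_csv(rows: list[dict]) -> str:
    """Hypothetical helper: serialise result rows using the export conventions above."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerow(EXPORT_COLUMNS)
    for row in rows:
        writer.writerow([
            row["customer_id"],
            round(row["churn_score"], 4),              # 4 decimal places
            "true" if row["is_at_risk"] else "false",  # lowercase, not Python's True/False
            row["contract_type"],
            row["tenure"],
            row["monthly_charges"],
        ])
    return buffer.getvalue()


rows = [{"customer_id": "CUST0004", "churn_score": 0.82314, "is_at_risk": True,
         "contract_type": "Month-to-month", "tenure": 41, "monthly_charges": 59.11}]
print(export_rows_to_csv(rows))
# → customer_id,churn_score,is_at_risk,contract_type,tenure,monthly_charges
# → CUST0004,0.8231,true,Month-to-month,41,59.11
```

Mapping the flag to a string before writing is what produces lowercase true/false; writing the Python bool directly would give True/False.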

Sample export output:

customer_id,churn_score,is_at_risk,contract_type,tenure,monthly_charges
CUST0004,0.8231,true,Month-to-month,41,59.11
CUST0005,0.7654,true,Month-to-month,12,47.03
CUST0009,0.1203,false,One year,70,97.46

Testing Strategy

The project uses two complementary testing approaches.

Example-based tests (pytest)

These cover specific scenarios and exact error messages:

# tests/unit/test_validator.py
def test_rejects_non_csv_file():
    result = validate_upload(b"some data", "data.xlsx", 100)
    assert result.success is False
    assert "xlsx" in result.error_message

def test_tenure_zero_is_valid():
    csv = b"customer_id,tenure,monthly_charges,total_charges,contract_type\n"
    csv += b"C001,0,50.0,0.01,Month-to-month\n"
    result = validate_upload(csv, "test.csv", len(csv))
    assert result.success is True
    assert result.valid_rows == 1

Property-based tests (Hypothesis)

These verify universal correctness properties across thousands of generated inputs:

# tests/property/test_predictor_properties.py
import numpy as np
from hypothesis import given, settings
import hypothesis.strategies as st

from predictor import classify_at_risk  # assumed location; adjust to your module layout

@given(
    scores=st.lists(st.floats(0.0, 1.0), min_size=1).map(np.array),
    threshold=st.floats(0.0, 1.0)
)
@settings(max_examples=200)
def test_classify_at_risk_consistency(scores, threshold):
    """
    Property 2: For any scores and threshold, classify_at_risk returns
    True iff score >= threshold, including for identical scores.
    """
    flags = classify_at_risk(scores, threshold)
    for score, flag in zip(scores, flags):
        assert flag == (score >= threshold)

@given(st.integers(min_value=0))
def test_file_size_boundary(size):
    """
    Property 9: _check_file_size returns True iff size <= 52,428,800.
    """
    result = _check_file_size(size)
    assert result == (size <= 52_428_800)
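
The constant in Property 9 is 50 MB expressed in binary units (50 × 1024 × 1024 bytes). A minimal sketch of the check and its boundary cases follows. check_file_size mirrors the private _check_file_size named in the test, but the body is an assumption because the original is not shown.

```python
MAX_UPLOAD_BYTES = 50 * 1024 * 1024  # 52,428,800 bytes


def check_file_size(size_bytes: int) -> bool:
    """Assumed implementation: accept files up to and including the limit."""
    return size_bytes <= MAX_UPLOAD_BYTES


print(MAX_UPLOAD_BYTES)                       # → 52428800
print(check_file_size(MAX_UPLOAD_BYTES))      # → True   (exactly at the limit)
print(check_file_size(MAX_UPLOAD_BYTES + 1))  # → False  (one byte over)
```

A property-based test is a good fit here because off-by-one mistakes at the boundary are exactly the kind of bug that a handful of hand-picked examples tends to miss.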

Key Design Decisions

Why a singleton AppState instead of Flask sessions?
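Flask's default sessions are stored in a signed cookie, which browsers cap at roughly 4 KB, so they can't hold DataFrames. For a single-user analytics tool, a module-level singleton is simpler and more practical than a database or Redis cache. The trade-off is that the state lives in one process: it is lost on restart and is not shared across multiple workers.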

Why Plotly JSON instead of server-rendered images?
Plotly charts are interactive: users can hover, zoom, and pan. Serialising chart specs as JSON and rendering client-side means the server doesn't need a headless browser or image generation library.

Why separate table_helpers.py?
Keeping pagination, sort, and filter as pure functions makes them trivially testable without spinning up a Flask test client. The JavaScript mirrors the same logic for instant client-side interactivity.

Why one-hot encode at prediction time?
The uploaded CSV may not contain all three contract types. Encoding at prediction time and filling missing columns with 0 ensures the feature vector always matches what the model was trained on.


Running the Full App

# 1. Install dependencies
pip install -r requirements.txt

# 2. Train and save the model
python generate_model.py

# 3. Start the server
python app.py

Open http://localhost:5000, upload data/sample_customers.csv, and click Predict Churn.

You should see:

  • Summary stats (total customers, at-risk count, churn rate %)
  • Three interactive Plotly charts
  • A paginated, sortable, filterable table of at-risk customers
  • An export button to download the full results as CSV

What's Next

A few natural extensions from here:

  • User authentication - add Flask-Login for multi-user support with per-user state
  • Model retraining - add an admin route to upload new training data and retrain in-place
  • Scheduled batch jobs - use Celery + Redis to run predictions on a schedule and email results
  • Database persistence - swap the in-memory AppState for SQLAlchemy + PostgreSQL to persist results across restarts
  • SHAP explanations - add feature importance explanations per customer using the shap library

Source Code

The full source is available on GitHub: Agentic AI Kiro

The project includes:

  • All Python modules with docstrings
  • Sample CSV data (200 rows)
  • generate_model.py to reproduce the model
  • Unit, integration, and property-based tests

Try Kiro Yourself

If you want to build something like this, or anything else, Kiro is worth trying. The spec-driven workflow changes how you approach a project. Instead of diving straight into code and figuring out the design as you go, you start with a clear picture of what you're building and why. The code becomes the easy part.

The entire requirements document, technical design, task list, and implementation for this project came from a single prompt. That's the difference.


Built with Flask, pandas, scikit-learn, and Plotly. Spec-driven development powered by Kiro. Tested with pytest and Hypothesis.
