
Young Gao

CI/CD for ML Models: From Training Notebooks to Production Deployment in 2026

Deploying ML models isn't like deploying web apps. Models have data dependencies, training artifacts, hardware requirements, and failure modes that traditional CI/CD pipelines don't handle. Most teams either deploy models manually (risky) or bolt ML onto web app pipelines (broken).

This guide shows how to build ML-specific CI/CD pipelines that actually work.

Why Standard CI/CD Breaks for ML

Web App Pipeline:        ML Pipeline:
Code → Build → Test →    Code → Build → Test →
Deploy → Monitor         Train → Validate → Package →
                         Deploy → Monitor → Retrain

ML adds three challenges:

  1. Artifacts are large — models are 100MB-100GB, not 10MB Docker images
  2. Tests need data — you can't unit test a model without test datasets
  3. Rollbacks need both code AND model — rolling back the code while keeping the new model (or vice versa) leaves an incompatible pair
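One way to handle the third point is to treat the serving code and the model as a single release unit, so a rollback always restores a matched pair. A minimal sketch — the `Release` structure and `rollback_target` helper are illustrative, not a standard format:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Release:
    git_sha: str        # serving-code revision
    model_version: int  # registry version validated against that revision

def rollback_target(history: list[Release]) -> Release:
    """Return the previous release: code sha and model version
    move together as one unit, never independently."""
    if len(history) < 2:
        raise RuntimeError("no previous release to roll back to")
    return history[-2]

# Rolling back restores both fields of the earlier release
history = [Release("a1b2c3", 7), Release("d4e5f6", 8)]
prev = rollback_target(history)
```

Whatever the exact shape, the point is that your deployment tooling should only ever consume these pairs, never a bare model version.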

The Pipeline Architecture

# .github/workflows/ml-pipeline.yml
name: ML Model Pipeline

on:
  push:
    paths:
      - 'models/**'
      - 'training/**'
      - 'serving/**'
  workflow_dispatch:
    inputs:
      retrain:
        description: 'Force retraining'
        type: boolean
        default: false

jobs:
  validate-data:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Validate training data schema
        run: |
          python -m pytest tests/data/ -v \
            --tb=short \
            -k "test_schema or test_distribution"

      - name: Check for data drift
        run: |
          python scripts/check_drift.py \
            --reference data/reference_stats.json \
            --current data/latest/ \
            --threshold 0.05

  train:
    needs: validate-data
    runs-on: [self-hosted, gpu]
    steps:
      - uses: actions/checkout@v4

      - name: Train model
        env:
          MLFLOW_TRACKING_URI: ${{ secrets.MLFLOW_URI }}
          WANDB_API_KEY: ${{ secrets.WANDB_KEY }}
        run: |
          python training/train.py \
            --config configs/production.yaml \
            --experiment-name "ci-${{ github.sha }}" \
            --output-dir artifacts/

      - name: Upload model artifact
        uses: actions/upload-artifact@v4
        with:
          name: model-${{ github.sha }}
          path: artifacts/
          retention-days: 30

  evaluate:
    needs: train
    runs-on: [self-hosted, gpu]
    steps:
      - uses: actions/checkout@v4

      - name: Download model
        uses: actions/download-artifact@v4
        with:
          name: model-${{ github.sha }}
          path: artifacts/

      - name: Run evaluation suite
        run: |
          python evaluation/evaluate.py \
            --model artifacts/model.pt \
            --test-data data/test/ \
            --output evaluation_report.json

      - name: Check quality gates
        run: |
          python scripts/check_gates.py \
            --report evaluation_report.json \
            --min-accuracy 0.92 \
            --max-latency-p99 100 \
            --max-model-size 500MB
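The `scripts/check_gates.py` step above is just a small gate-checker over the evaluation report. A minimal sketch of its core logic — the report field names here are assumptions, not a fixed format:

```python
def check_gates(report: dict, min_accuracy: float,
                max_latency_p99_ms: float, max_model_size_mb: float) -> list[str]:
    """Return a list of failed gates; an empty list means the model passes."""
    failures = []
    if report["accuracy"] < min_accuracy:
        failures.append(f"accuracy {report['accuracy']:.4f} < {min_accuracy}")
    if report["latency_p99_ms"] > max_latency_p99_ms:
        failures.append(f"p99 latency {report['latency_p99_ms']:.0f}ms > {max_latency_p99_ms}ms")
    if report["model_size_mb"] > max_model_size_mb:
        failures.append(f"size {report['model_size_mb']:.0f}MB > {max_model_size_mb}MB")
    return failures

passing = check_gates(
    {"accuracy": 0.94, "latency_p99_ms": 82, "model_size_mb": 410},
    min_accuracy=0.92, max_latency_p99_ms=100, max_model_size_mb=500,
)
failing = check_gates(
    {"accuracy": 0.90, "latency_p99_ms": 120, "model_size_mb": 410},
    min_accuracy=0.92, max_latency_p99_ms=100, max_model_size_mb=500,
)
```

In CI the script would `json.load` the report, print each failure, and exit non-zero when the list is non-empty, failing the job.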

Model Versioning That Works

Don't use Git LFS for models. Use a proper model registry.

# model_registry.py
import mlflow
from dataclasses import dataclass
from datetime import datetime

@dataclass
class ModelVersion:
    name: str
    version: int
    stage: str  # "staging" | "production" | "archived"
    metrics: dict
    git_sha: str
    created_at: datetime

class ModelRegistry:
    def __init__(self, tracking_uri: str):
        mlflow.set_tracking_uri(tracking_uri)

    def register(self, model_path: str, name: str,
                 metrics: dict, git_sha: str) -> ModelVersion:
        """Register a trained model with its metrics."""
        with mlflow.start_run():
            mlflow.log_metrics(metrics)
            mlflow.log_param("git_sha", git_sha)

            # Load the checkpoint so MLflow receives a model object,
            # not a filesystem path
            import torch
            model = torch.load(model_path, weights_only=False)
            model_info = mlflow.pytorch.log_model(
                pytorch_model=model,
                artifact_path="model",
                registered_model_name=name,
            )

        # Get the new version number
        client = mlflow.MlflowClient()
        versions = client.get_latest_versions(name)
        latest = max(versions, key=lambda v: v.version)

        return ModelVersion(
            name=name,
            version=latest.version,
            stage="staging",
            metrics=metrics,
            git_sha=git_sha,
            created_at=datetime.utcnow(),
        )

    def promote(self, name: str, version: int, stage: str):
        """Move model to a new stage (staging → production)."""
        client = mlflow.MlflowClient()
        client.transition_model_version_stage(
            name=name, version=version, stage=stage,
        )

    def get_production(self, name: str) -> ModelVersion | None:
        """Get the current production model version."""
        client = mlflow.MlflowClient()
        versions = client.get_latest_versions(name, stages=["Production"])
        if not versions:
            return None
        v = versions[0]
        run = client.get_run(v.run_id)
        return ModelVersion(
            name=name,
            version=v.version,
            stage="production",
            metrics=run.data.metrics,
            git_sha=run.data.params.get("git_sha", "unknown"),
            created_at=datetime.fromtimestamp(v.creation_timestamp / 1000),  # epoch ms
        )
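Stage transitions deserve a guard of their own in CI: a model should only reach production from staging. A minimal sketch of such a policy check — the allowed transitions below are an assumption; adapt them to your registry's lifecycle:

```python
# Allowed lifecycle moves for a registered model version (hypothetical policy)
ALLOWED_TRANSITIONS = {
    ("none", "staging"),
    ("staging", "production"),
    ("staging", "archived"),
    ("production", "archived"),
}

def can_transition(current: str, target: str) -> bool:
    """True if moving a model version from `current` to `target` is allowed."""
    return (current.lower(), target.lower()) in ALLOWED_TRANSITIONS
```

A `promote()` method could run this check before calling the registry, refusing, say, to revive an archived version straight into production.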

Automated Model Testing

ML models need three types of tests:

# tests/test_model.py
import pytest
import torch
import numpy as np

class TestModelQuality:
    """Tests that run against the trained model."""

    @pytest.fixture
    def model(self):
        # weights_only=False: the file stores a full pickled model object,
        # not just a state_dict
        return torch.load("artifacts/model.pt", weights_only=False)

    @pytest.fixture
    def test_data(self):
        # load_test_dataset is a project-specific helper returning an
        # object with .features and .labels
        return load_test_dataset("data/test/")

    def test_accuracy_above_threshold(self, model, test_data):
        """Model must maintain minimum accuracy."""
        predictions = model.predict(test_data.features)
        accuracy = (predictions == test_data.labels).mean()
        assert accuracy >= 0.92, f"Accuracy {accuracy:.4f} below threshold 0.92"

    def test_no_class_collapse(self, model, test_data):
        """Model must predict all classes, not just the majority."""
        predictions = model.predict(test_data.features)
        unique_classes = set(predictions.tolist())
        expected_classes = set(test_data.labels.unique().tolist())
        missing = expected_classes - unique_classes
        assert not missing, f"Model never predicts classes: {missing}"

    def test_latency_under_budget(self, model):
        """Single inference must complete within latency budget."""
        model = model.cuda().eval()
        dummy_input = torch.randn(1, 768, device="cuda")  # single sample
        # Warm up so one-time initialization doesn't skew timings
        for _ in range(10):
            with torch.no_grad():
                model(dummy_input)
        times = []
        for _ in range(100):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            with torch.no_grad():
                model(dummy_input)
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))
        p99 = np.percentile(times, 99)
        assert p99 < 100, f"p99 latency {p99:.1f}ms exceeds 100ms budget"

    def test_model_size_under_limit(self):
        """Model file must fit deployment constraints."""
        import os
        size_mb = os.path.getsize("artifacts/model.pt") / (1024 * 1024)
        assert size_mb < 500, f"Model size {size_mb:.0f}MB exceeds 500MB limit"


class TestModelRobustness:
    """Tests for edge cases and adversarial inputs."""

    @pytest.fixture
    def model(self):
        # weights_only=False: the file stores a full pickled model object
        return torch.load("artifacts/model.pt", weights_only=False)

    def test_handles_zero_input(self, model):
        """Model should handle a degenerate all-zeros input gracefully."""
        zeros = torch.zeros(1, 768)
        result = model(zeros)
        assert result is not None
        assert not torch.isnan(result).any()

    def test_handles_extreme_values(self, model):
        """Model shouldn't produce NaN on large inputs."""
        extreme = torch.ones(1, 768) * 1e6
        result = model(extreme)
        assert not torch.isnan(result).any()
        assert not torch.isinf(result).any()

    def test_deterministic_inference(self, model):
        """Same input should produce the same output in eval mode."""
        model.eval()  # disable dropout and other stochastic layers
        sample = torch.randn(1, 768)
        result1 = model(sample)
        result2 = model(sample)
        assert torch.allclose(result1, result2)

Canary Deployment for Models

# canary_deployer.py
import asyncio
import httpx
from dataclasses import dataclass

@dataclass
class CanaryConfig:
    initial_weight: float = 0.05     # Start with 5% traffic
    step_weight: float = 0.10        # Increase by 10% each step
    step_interval: int = 300         # 5 minutes between steps
    error_threshold: float = 0.01    # Rollback if error rate > 1%
    latency_threshold_ms: float = 150  # Rollback if p99 > 150ms

class CanaryDeployer:
    def __init__(self, config: CanaryConfig, metrics_client, lb_url: str):
        self.config = config
        self.metrics = metrics_client
        self.lb_url = lb_url  # load balancer admin endpoint

    async def deploy(self, model_version: str) -> bool:
        """Progressive canary deployment with automatic rollback."""
        weight = self.config.initial_weight

        # Deploy canary with initial weight
        await self._set_traffic_weight("canary", weight)
        print(f"Canary deployed at {weight*100:.0f}% traffic")

        while weight < 1.0:
            # Wait for metrics to accumulate (non-blocking in async code)
            await asyncio.sleep(self.config.step_interval)

            # Check health metrics
            error_rate = await self.metrics.get_error_rate("canary", window="5m")
            p99_latency = await self.metrics.get_latency_p99("canary", window="5m")

            if error_rate > self.config.error_threshold:
                print(f"ROLLBACK: Error rate {error_rate:.4f} > {self.config.error_threshold}")
                await self._rollback()
                return False

            if p99_latency > self.config.latency_threshold_ms:
                print(f"ROLLBACK: p99 latency {p99_latency:.0f}ms > {self.config.latency_threshold_ms}ms")
                await self._rollback()
                return False

            # Increase traffic
            weight = min(weight + self.config.step_weight, 1.0)
            await self._set_traffic_weight("canary", weight)
            print(f"Canary weight increased to {weight*100:.0f}%")

        # Full rollout successful
        await self._promote_canary()
        print(f"Model {model_version} fully deployed")
        return True

    async def _set_traffic_weight(self, variant: str, weight: float):
        """Update load balancer traffic split in favor of `variant`."""
        other = "stable" if variant == "canary" else "canary"
        async with httpx.AsyncClient() as client:
            await client.put(
                f"{self.lb_url}/routes/model/weights",
                json={variant: weight, other: 1 - weight},
            )

    async def _rollback(self):
        """Immediately route all traffic back to stable."""
        await self._set_traffic_weight("stable", 1.0)

    async def _promote_canary(self):
        """Make canary the new stable version."""
        await self._set_traffic_weight("canary", 1.0)
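With the defaults above (5% initial weight, +10% per step, 5 minutes per step), it's worth knowing how long a full rollout takes before an incident forces you to find out. A small helper to compute the schedule — the function name is hypothetical:

```python
def rollout_schedule(initial: float = 0.05, step: float = 0.10,
                     interval_s: int = 300) -> tuple[list[float], float]:
    """Return the sequence of canary traffic weights and the total
    minutes needed to reach 100% if every health check passes."""
    weights = [initial]
    w = initial
    while w < 1.0:
        w = min(w + step, 1.0)
        weights.append(w)
    total_minutes = (len(weights) - 1) * interval_s / 60
    return weights, total_minutes

weights, minutes = rollout_schedule()
```

So the happy path takes roughly 50 minutes to reach full traffic, while a rollback at any step is immediate.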

Data Drift Detection

Models degrade silently when input data changes. Catch it in CI.

# scripts/check_drift.py
import json
import numpy as np
from scipy import stats

def check_drift(reference_path: str, current_path: str,
                threshold: float = 0.05) -> dict:
    """Compare current data distribution against reference baseline."""
    with open(reference_path) as f:
        reference = json.load(f)

    # compute_stats is a project-specific helper that summarizes each
    # feature (type, sample, counts) in the same shape as the reference
    current_stats = compute_stats(current_path)
    drift_report = {"features": {}, "drifted": False}

    for feature_name, ref_stats in reference["features"].items():
        if feature_name not in current_stats:
            drift_report["features"][feature_name] = {
                "status": "MISSING",
                "severity": "critical",
            }
            drift_report["drifted"] = True
            continue

        curr = current_stats[feature_name]

        # KS test for continuous features
        if ref_stats["type"] == "numeric":
            ks_stat, p_value = stats.ks_2samp(
                ref_stats["sample"], curr["sample"]
            )
            drifted = p_value < threshold
        # Chi-squared for categorical features; rescale the reference
        # counts so expected and observed totals match
        else:
            observed = np.asarray(curr["counts"], dtype=float)
            expected = np.asarray(ref_stats["counts"], dtype=float)
            expected = expected * observed.sum() / expected.sum()
            chi2, p_value = stats.chisquare(observed, expected)
            drifted = p_value < threshold

        drift_report["features"][feature_name] = {
            "status": "DRIFT" if drifted else "OK",
            "p_value": float(p_value),
            "severity": "warning" if drifted else "none",
        }
        if drifted:
            drift_report["drifted"] = True

    return drift_report

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--reference", required=True)
    parser.add_argument("--current", required=True)
    parser.add_argument("--threshold", type=float, default=0.05)
    args = parser.parse_args()

    report = check_drift(args.reference, args.current, args.threshold)
    print(json.dumps(report, indent=2))

    if report["drifted"]:
        print("\n⚠ Data drift detected! Review before deploying.")
        exit(1)
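To see what the KS branch is reacting to, here is a dependency-free illustration: a pure-Python two-sample KS statistic (the maximum gap between empirical CDFs) on a drifted and a non-drifted sample. This sketches the statistic itself; `scipy.stats.ks_2samp` additionally gives you the p-value used above.

```python
import random

def ks_statistic(a: list[float], b: list[float]) -> float:
    """Maximum distance between the empirical CDFs of two samples."""
    a, b = sorted(a), sorted(b)
    i = j = 0
    d = 0.0
    while i < len(a) and j < len(b):
        # Advance the pointer with the smaller value, tracking the CDF gap
        if a[i] <= b[j]:
            i += 1
        else:
            j += 1
        d = max(d, abs(i / len(a) - j / len(b)))
    return d

random.seed(0)
reference = [random.gauss(0, 1) for _ in range(500)]
same_dist = [random.gauss(0, 1) for _ in range(500)]
shifted = [random.gauss(1, 1) for _ in range(500)]  # mean drifted by +1

low_drift = ks_statistic(reference, same_dist)
high_drift = ks_statistic(reference, shifted)
```

A mean shift of one standard deviation produces a large, unambiguous statistic, while two draws from the same distribution stay near zero — which is exactly the separation the 0.05 p-value threshold exploits.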

The Complete Pipeline Flow

1. Code Push
   ├── validate-data: schema check + drift detection
   ├── lint-and-test: standard code quality
   │
2. Training (on GPU runner)
   ├── train model with experiment tracking
   ├── register in model registry (staging)
   │
3. Evaluation
   ├── accuracy gate (>= 0.92)
   ├── latency gate (p99 < 100ms)
   ├── size gate (< 500MB)
   ├── robustness tests (NaN, edge cases)
   │
4. Deployment
   ├── canary at 5% traffic
   ├── monitor error rate + latency
   ├── step up 10% every 5 min
   ├── full rollout or auto-rollback
   │
5. Post-Deploy
   ├── continuous drift monitoring
   ├── A/B metrics comparison
   └── automated retrain trigger

Key Takeaways

  1. Version models separately from code — use a model registry, not Git LFS.
  2. Test models like software — accuracy thresholds, latency budgets, edge cases.
  3. Detect data drift in CI — models silently degrade when input distributions shift.
  4. Deploy with canary rollouts — gradual traffic shifting with automatic rollback.
  5. Train on GPU runners — use self-hosted runners or cloud GPU instances for CI training.

ML deployment isn't a solved problem, but these patterns handle 90% of production failure modes. Start with quality gates and canary deployments — they catch the most issues with the least infrastructure.


Based on ML deployment pipelines processing 50M+ predictions/day across multiple model versions.
