CI/CD for ML Models: From Training Notebooks to Production Deployment in 2026
Deploying ML models isn't like deploying web apps. Models have data dependencies, training artifacts, hardware requirements, and failure modes that traditional CI/CD pipelines don't handle. Most teams either deploy models manually (risky) or bolt ML onto web app pipelines (broken).
This guide shows how to build ML-specific CI/CD pipelines that actually work.
Why Standard CI/CD Breaks for ML
```
Web App Pipeline:  Code → Build → Test → Deploy → Monitor

ML Pipeline:       Code → Build → Test → Train → Validate →
                   Package → Deploy → Monitor → Retrain
```
ML adds three challenges:
- Artifacts are large — models are 100MB-100GB, not 10MB Docker images
- Tests need data — you can't unit test a model without test datasets
- Rollbacks need both code AND model — deploying old code with new model (or vice versa) breaks things
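One way to make that last point concrete is to version deployments as a (code SHA, model version) pair, so a rollback always restores both halves together. A minimal sketch — the `Deployment` record and `rollback` helper here are illustrative, not part of any particular tool:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Deployment:
    """A deployable unit: serving code and model are pinned together."""
    git_sha: str        # serving code revision
    model_version: int  # registry version validated against that code

def rollback(history: list[Deployment]) -> Deployment:
    """Roll back to the previous known-good pair, never just one half."""
    if len(history) < 2:
        raise RuntimeError("no previous deployment to roll back to")
    return history[-2]

history = [
    Deployment(git_sha="a1b2c3", model_version=7),
    Deployment(git_sha="d4e5f6", model_version=8),
]
previous = rollback(history)
print(previous.git_sha, previous.model_version)  # a1b2c3 7
```

The point of the frozen dataclass is that nothing can mutate one half of the pair after deployment — a rollback is always a swap of the whole record.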
The Pipeline Architecture
```yaml
# .github/workflows/ml-pipeline.yml
name: ML Model Pipeline

on:
  push:
    paths:
      - 'models/**'
      - 'training/**'
      - 'serving/**'
  workflow_dispatch:
    inputs:
      retrain:
        description: 'Force retraining'
        type: boolean
        default: false

jobs:
  validate-data:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Validate training data schema
        run: |
          python -m pytest tests/data/ -v \
            --tb=short \
            -k "test_schema or test_distribution"
      - name: Check for data drift
        run: |
          python scripts/check_drift.py \
            --reference data/reference_stats.json \
            --current data/latest/ \
            --threshold 0.05

  train:
    needs: validate-data
    runs-on: [self-hosted, gpu]
    steps:
      - uses: actions/checkout@v4
      - name: Train model
        env:
          MLFLOW_TRACKING_URI: ${{ secrets.MLFLOW_URI }}
          WANDB_API_KEY: ${{ secrets.WANDB_KEY }}
        run: |
          python training/train.py \
            --config configs/production.yaml \
            --experiment-name "ci-${{ github.sha }}" \
            --output-dir artifacts/
      - name: Upload model artifact
        uses: actions/upload-artifact@v4
        with:
          name: model-${{ github.sha }}
          path: artifacts/
          retention-days: 30

  evaluate:
    needs: train
    runs-on: [self-hosted, gpu]
    steps:
      - uses: actions/checkout@v4
      - name: Download model
        uses: actions/download-artifact@v4
        with:
          name: model-${{ github.sha }}
          path: artifacts/
      - name: Run evaluation suite
        run: |
          python evaluation/evaluate.py \
            --model artifacts/model.pt \
            --test-data data/test/ \
            --output evaluation_report.json
      - name: Check quality gates
        run: |
          python scripts/check_gates.py \
            --report evaluation_report.json \
            --min-accuracy 0.92 \
            --max-latency-p99 100 \
            --max-model-size 500MB
```
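The evaluate job calls `scripts/check_gates.py`, which fails the build when any metric misses its threshold. A minimal sketch of what such a gate checker might look like — the report keys (`accuracy`, `latency_p99_ms`, `model_size_mb`) are assumptions about what the evaluation step emits:

```python
# scripts/check_gates.py — minimal sketch of a quality-gate checker
import json
import sys

def check_gates(report: dict, min_accuracy: float,
                max_latency_p99: float, max_model_size_mb: float) -> list[str]:
    """Return a list of failed gates; an empty list means all gates passed."""
    failures = []
    if report["accuracy"] < min_accuracy:
        failures.append(f"accuracy {report['accuracy']:.4f} < {min_accuracy}")
    if report["latency_p99_ms"] > max_latency_p99:
        failures.append(f"p99 {report['latency_p99_ms']:.1f}ms > {max_latency_p99}ms")
    if report["model_size_mb"] > max_model_size_mb:
        failures.append(f"size {report['model_size_mb']:.0f}MB > {max_model_size_mb}MB")
    return failures

if __name__ == "__main__" and len(sys.argv) > 1:
    with open(sys.argv[1]) as f:
        report = json.load(f)
    failed = check_gates(report, 0.92, 100, 500)
    for gate in failed:
        print("GATE FAILED:", gate)
    sys.exit(1 if failed else 0)
```

Exiting non-zero is what actually blocks the pipeline — the gate logic itself is just comparisons.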
Model Versioning That Works
Don't use Git LFS for models. Use a proper model registry.
```python
# model_registry.py
import mlflow
import torch
from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class ModelVersion:
    name: str
    version: int
    stage: str  # "staging" | "production" | "archived"
    metrics: dict
    git_sha: str
    created_at: datetime


class ModelRegistry:
    def __init__(self, tracking_uri: str):
        mlflow.set_tracking_uri(tracking_uri)

    def register(self, model_path: str, name: str,
                 metrics: dict, git_sha: str) -> ModelVersion:
        """Register a trained model with its metrics."""
        with mlflow.start_run():
            mlflow.log_metrics(metrics)
            mlflow.log_param("git_sha", git_sha)
            # log_model expects a model object, not a file path
            model = torch.load(model_path, weights_only=False)
            mlflow.pytorch.log_model(
                pytorch_model=model,
                artifact_path="model",
                registered_model_name=name,
            )
            # Get the version number that was just created
            client = mlflow.MlflowClient()
            versions = client.get_latest_versions(name)
            latest = max(versions, key=lambda v: int(v.version))
            return ModelVersion(
                name=name,
                version=int(latest.version),
                stage="staging",
                metrics=metrics,
                git_sha=git_sha,
                created_at=datetime.now(timezone.utc),
            )

    def promote(self, name: str, version: int, stage: str):
        """Move a model version to a new stage (staging → production)."""
        client = mlflow.MlflowClient()
        client.transition_model_version_stage(
            name=name, version=version, stage=stage,
        )

    def get_production(self, name: str) -> ModelVersion | None:
        """Get the current production model version."""
        client = mlflow.MlflowClient()
        versions = client.get_latest_versions(name, stages=["Production"])
        if not versions:
            return None
        v = versions[0]
        run = client.get_run(v.run_id)
        return ModelVersion(
            name=name,
            version=int(v.version),
            stage="production",
            metrics=run.data.metrics,
            git_sha=run.data.params.get("git_sha", "unknown"),
            # creation_timestamp is epoch milliseconds, not an ISO string
            created_at=datetime.fromtimestamp(
                v.creation_timestamp / 1000, tz=timezone.utc
            ),
        )
```
Automated Model Testing
ML models need three types of tests: quality (accuracy), performance (latency and size budgets), and robustness (edge cases):
```python
# tests/test_model.py
import os

import numpy as np
import pytest
import torch

from evaluation.datasets import load_test_dataset  # project helper


class TestModelQuality:
    """Tests that run against the trained model."""

    @pytest.fixture
    def model(self):
        # weights_only=False because we load the full pickled module
        model = torch.load("artifacts/model.pt", weights_only=False)
        model.eval()
        return model

    @pytest.fixture
    def test_data(self):
        return load_test_dataset("data/test/")

    def test_accuracy_above_threshold(self, model, test_data):
        """Model must maintain minimum accuracy."""
        predictions = model.predict(test_data.features)
        accuracy = (predictions == test_data.labels).float().mean().item()
        assert accuracy >= 0.92, f"Accuracy {accuracy:.4f} below threshold 0.92"

    def test_no_class_collapse(self, model, test_data):
        """Model must predict all classes, not just the majority."""
        predictions = model.predict(test_data.features)
        unique_classes = set(predictions.tolist())
        expected_classes = set(test_data.labels.unique().tolist())
        missing = expected_classes - unique_classes
        assert not missing, f"Model never predicts classes: {missing}"

    def test_latency_under_budget(self, model):
        """Single inference must complete within latency budget."""
        model = model.cuda()  # CUDA events require model and input on GPU
        dummy_input = torch.randn(1, 768, device="cuda")  # single sample
        times = []
        for _ in range(100):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            with torch.no_grad():
                model(dummy_input)
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))
        p99 = np.percentile(times, 99)
        assert p99 < 100, f"p99 latency {p99:.1f}ms exceeds 100ms budget"

    def test_model_size_under_limit(self):
        """Model file must fit deployment constraints."""
        size_mb = os.path.getsize("artifacts/model.pt") / (1024 * 1024)
        assert size_mb < 500, f"Model size {size_mb:.0f}MB exceeds 500MB limit"


class TestModelRobustness:
    """Tests for edge cases and adversarial inputs."""

    @pytest.fixture
    def model(self):
        model = torch.load("artifacts/model.pt", weights_only=False)
        model.eval()
        return model

    def test_handles_zero_input(self, model):
        """Model should handle all-zero inputs gracefully."""
        zeros = torch.zeros(1, 768)
        result = model(zeros)
        assert result is not None
        assert not torch.isnan(result).any()

    def test_handles_extreme_values(self, model):
        """Model shouldn't produce NaN or Inf on large inputs."""
        extreme = torch.ones(1, 768) * 1e6
        result = model(extreme)
        assert not torch.isnan(result).any()
        assert not torch.isinf(result).any()

    def test_deterministic_inference(self, model):
        """Same input should produce the same output in eval mode."""
        sample = torch.randn(1, 768)
        with torch.no_grad():
            result1 = model(sample)
            result2 = model(sample)
        assert torch.allclose(result1, result2)
```
Canary Deployment for Models
Never cut 100% of traffic over to a new model at once. Shift traffic to it gradually and roll back automatically when metrics regress:
```python
# canary_deployer.py
import asyncio
from dataclasses import dataclass

import httpx


@dataclass
class CanaryConfig:
    initial_weight: float = 0.05       # start with 5% traffic
    step_weight: float = 0.10          # increase by 10% each step
    step_interval: int = 300           # 5 minutes between steps
    error_threshold: float = 0.01      # roll back if error rate > 1%
    latency_threshold_ms: float = 150  # roll back if p99 > 150ms


class CanaryDeployer:
    def __init__(self, config: CanaryConfig, metrics_client, lb_url: str):
        self.config = config
        self.metrics = metrics_client
        self.lb_url = lb_url

    async def deploy(self, model_version: str) -> bool:
        """Progressive canary deployment with automatic rollback."""
        weight = self.config.initial_weight

        # Deploy canary with initial weight
        await self._set_canary_weight(weight)
        print(f"Canary deployed at {weight * 100:.0f}% traffic")

        while weight < 1.0:
            # Wait for metrics to accumulate (don't block the event loop)
            await asyncio.sleep(self.config.step_interval)

            # Check health metrics
            error_rate = await self.metrics.get_error_rate("canary", window="5m")
            p99_latency = await self.metrics.get_latency_p99("canary", window="5m")

            if error_rate > self.config.error_threshold:
                print(f"ROLLBACK: Error rate {error_rate:.4f} > {self.config.error_threshold}")
                await self._rollback()
                return False

            if p99_latency > self.config.latency_threshold_ms:
                print(f"ROLLBACK: p99 latency {p99_latency:.0f}ms > {self.config.latency_threshold_ms}ms")
                await self._rollback()
                return False

            # Increase traffic
            weight = min(weight + self.config.step_weight, 1.0)
            await self._set_canary_weight(weight)
            print(f"Canary weight increased to {weight * 100:.0f}%")

        # Full rollout successful
        await self._promote_canary()
        print(f"Model {model_version} fully deployed")
        return True

    async def _set_canary_weight(self, weight: float):
        """Update the load balancer traffic split."""
        async with httpx.AsyncClient() as client:
            await client.put(
                f"{self.lb_url}/routes/model/weights",
                json={"canary": weight, "stable": 1 - weight},
            )

    async def _rollback(self):
        """Immediately route all traffic back to stable."""
        await self._set_canary_weight(0.0)

    async def _promote_canary(self):
        """Make canary the new stable version. Re-registering the canary
        backend as stable is load-balancer-specific; here we pin it at
        full weight."""
        await self._set_canary_weight(1.0)
```
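Before committing to these thresholds, it's worth knowing how long a full rollout takes with the default config. A quick sketch — pure arithmetic on `CanaryConfig`'s default values:

```python
def rollout_schedule(initial: float = 0.05, step: float = 0.10,
                     interval_s: int = 300) -> list[float]:
    """Traffic weights the canary passes through before full rollout."""
    weights = [initial]
    while weights[-1] < 1.0:
        # Same update rule as the deploy loop: step up, capped at 100%
        weights.append(min(weights[-1] + step, 1.0))
    return weights

weights = rollout_schedule()
print([f"{w:.2f}" for w in weights])  # 0.05, 0.15, ..., 0.95, 1.00
print(f"time to full rollout: {(len(weights) - 1) * 300 / 60:.0f} minutes")  # 50 minutes
```

Ten 5-minute steps after the initial weight means roughly 50 minutes from first canary traffic to full rollout — tune `step_interval` to how long your metrics need to become statistically meaningful.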
Data Drift Detection
Models degrade silently when input data changes. Catch it in CI.
```python
# scripts/check_drift.py
import json
import sys

import numpy as np
from scipy import stats


def check_drift(reference_path: str, current_path: str,
                threshold: float = 0.05) -> dict:
    """Compare current data distribution against reference baseline."""
    with open(reference_path) as f:
        reference = json.load(f)

    # compute_stats is a project helper that summarizes the current
    # dataset in the same format as the reference baseline
    current_stats = compute_stats(current_path)
    drift_report = {"features": {}, "drifted": False}

    for feature_name, ref_stats in reference["features"].items():
        if feature_name not in current_stats:
            drift_report["features"][feature_name] = {
                "status": "MISSING",
                "severity": "critical",
            }
            drift_report["drifted"] = True
            continue

        curr = current_stats[feature_name]

        if ref_stats["type"] == "numeric":
            # KS test for continuous features
            ks_stat, p_value = stats.ks_2samp(
                ref_stats["sample"], curr["sample"]
            )
        else:
            # Chi-squared for categorical features; expected counts must
            # be rescaled so both frequency arrays sum to the same total
            ref_counts = np.asarray(ref_stats["counts"], dtype=float)
            curr_counts = np.asarray(curr["counts"], dtype=float)
            expected = ref_counts / ref_counts.sum() * curr_counts.sum()
            chi2, p_value = stats.chisquare(curr_counts, expected)

        drifted = p_value < threshold
        drift_report["features"][feature_name] = {
            "status": "DRIFT" if drifted else "OK",
            "p_value": float(p_value),
            "severity": "warning" if drifted else "none",
        }
        if drifted:
            drift_report["drifted"] = True

    return drift_report


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--reference", required=True)
    parser.add_argument("--current", required=True)
    parser.add_argument("--threshold", type=float, default=0.05)
    args = parser.parse_args()

    report = check_drift(args.reference, args.current, args.threshold)
    print(json.dumps(report, indent=2))

    if report["drifted"]:
        print("\n⚠ Data drift detected! Review before deploying.")
        sys.exit(1)
```
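To see the KS test catch a shift, compare a reference sample against one drawn from a drifted distribution (synthetic data with a fixed seed):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(42)
reference = rng.normal(loc=0.0, scale=1.0, size=5000)
same = rng.normal(loc=0.0, scale=1.0, size=5000)     # same distribution
shifted = rng.normal(loc=0.5, scale=1.0, size=5000)  # mean moved by 0.5

_, p_same = stats.ks_2samp(reference, same)
_, p_shifted = stats.ks_2samp(reference, shifted)

print(f"same distribution: p = {p_same:.3f}")     # typically large: no drift flagged
print(f"shifted by 0.5:    p = {p_shifted:.2e}")  # tiny: drift flagged
```

A mean shift of half a standard deviation at n = 5000 is unmissable for the KS test; much smaller shifts need larger samples, which is why the reference baseline should store a decent-sized sample per feature.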
The Complete Pipeline Flow
1. Code Push
├── validate-data: schema check + drift detection
├── lint-and-test: standard code quality
│
2. Training (on GPU runner)
├── train model with experiment tracking
├── register in model registry (staging)
│
3. Evaluation
├── accuracy gate (>= 0.92)
├── latency gate (p99 < 100ms)
├── size gate (< 500MB)
├── robustness tests (NaN, edge cases)
│
4. Deployment
├── canary at 5% traffic
├── monitor error rate + latency
├── step up 10% every 5 min
├── full rollout or auto-rollback
│
5. Post-Deploy
├── continuous drift monitoring
├── A/B metrics comparison
└── automated retrain trigger
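The retrain trigger in step 5 can be as simple as a scheduled job that combines the drift report with model age. A sketch — the thresholds and the shape of the decision are illustrative, not prescriptive:

```python
from datetime import datetime, timedelta, timezone

def should_retrain(drift_report: dict, last_trained: datetime,
                   max_age_days: int = 30) -> tuple[bool, str]:
    """Decide whether to kick off retraining, with a reason for the audit log."""
    critical = [
        name for name, f in drift_report["features"].items()
        if f.get("severity") == "critical"
    ]
    if critical:
        return True, f"critical drift in features: {critical}"
    if drift_report["drifted"]:
        return True, "distribution drift detected"
    age = datetime.now(timezone.utc) - last_trained
    if age > timedelta(days=max_age_days):
        return True, f"model is {age.days} days old"
    return False, "no trigger fired"

report = {"drifted": True, "features": {"income": {"severity": "warning"}}}
decision, reason = should_retrain(report, datetime.now(timezone.utc))
print(decision, reason)  # True distribution drift detected
```

Returning a reason string matters more than it looks: when retraining fires at 3 a.m., the audit log should say why.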
Key Takeaways
- Version models separately from code — use a model registry, not Git LFS.
- Test models like software — accuracy thresholds, latency budgets, edge cases.
- Detect data drift in CI — models silently degrade when input distributions shift.
- Deploy with canary rollouts — gradual traffic shifting with automatic rollback.
- Train on GPU runners — use self-hosted runners or cloud GPU instances for CI training.
ML deployment isn't a solved problem, but these patterns handle 90% of production failure modes. Start with quality gates and canary deployments — they catch the most issues with the least infrastructure.
Based on ML deployment pipelines processing 50M+ predictions/day across multiple model versions.