ML Testing Framework
ML code breaks in ways that traditional software tests don't catch. Your model can pass every unit test and still fail catastrophically in production because the data distribution shifted, a feature pipeline started returning nulls, or the model performs well on average but terribly for a protected subgroup. This framework gives you test patterns built specifically for machine learning: data validation gates, model performance regression detection, fairness and bias checks, and integration tests that verify the full prediction pipeline end-to-end.
Key Features
- Data Validation Tests — Schema enforcement, null rate checks, distribution drift detection, and outlier flagging.
- Model Performance Tests — Assertions on accuracy, F1, AUC, latency with statistical significance thresholds.
- Fairness & Bias Checks — Demographic parity, equalized odds, and disparate impact ratio tests across protected attributes.
- Regression Detection — Compare new model versions against baselines, catching performance drops before deployment.
- Integration Tests — End-to-end pipeline tests: feature engineering → inference → post-processing → API response.
- Property-Based Tests — Invariance tests and directional expectation tests for model robustness.
Quick Start
```bash
unzip ml-testing-framework.zip && cd ml-testing-framework
pip install -r requirements.txt

# Run the full test suite
pytest tests/ -v --ml-report=report.html

# Run only data validation tests
pytest tests/ -v -m "data_validation"
```
```yaml
# config.example.yaml
data_validation:
  schema:
    required_columns: [user_id, feature_1, feature_2, target]
    column_types: { user_id: int, feature_1: float, feature_2: float, target: int }
  quality:
    max_null_rate: 0.05
    max_duplicate_rate: 0.01
    expected_row_count_min: 10000

model_performance:
  metrics:
    accuracy: { min_threshold: 0.85, regression_tolerance: 0.02 }
    f1_macro: { min_threshold: 0.80 }
    latency_p99_ms: { max_threshold: 100 }

fairness:
  protected_attributes: [gender, age_group]
  metrics:
    demographic_parity_ratio: { min_threshold: 0.8 }
    equalized_odds_difference: { max_threshold: 0.1 }

regression:
  baseline_model_path: ./models/baseline_v1.pt
  test_dataset: ./data/holdout_test.parquet
  comparison_metrics: [accuracy, f1_macro, auc_roc]
```
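The `max_null_rate` quality gate above maps to very simple logic. A minimal pure-Python sketch of how such a check might work (the column dict and function names here are illustrative, not the framework's actual implementation):

```python
# Sketch of a max_null_rate quality gate like the one configured above.
# Sample data and function names are illustrative, not framework code.

def null_rate(values):
    """Fraction of entries that are None."""
    return sum(v is None for v in values) / len(values)

def check_null_rates(columns, max_rate=0.05):
    """Return {column: rate} for every column whose null rate exceeds max_rate."""
    return {name: null_rate(vals)
            for name, vals in columns.items()
            if null_rate(vals) > max_rate}

data = {
    "feature_1": [1.0, 2.0, None, 4.0] * 25,  # 25% nulls -> fails the gate
    "feature_2": [0.5] * 100,                 # 0% nulls  -> passes
}
violations = check_null_rates(data, max_rate=0.05)
print(violations)  # {'feature_1': 0.25}
```

A test would then assert that `violations` is empty, failing the pipeline before any model metric is computed.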
Architecture
```
┌────────────────┐    ┌────────────────┐    ┌────────────────┐
│      Data      │    │     Model      │    │    Fairness    │
│   Validation   │    │  Performance   │    │     & Bias     │
│     Tests      │    │     Tests      │    │     Tests      │
└───────┬────────┘    └───────┬────────┘    └───────┬────────┘
        │                     │                     │
        └─────────────────────┼─────────────────────┘
                              │
                      ┌───────▼────────┐
                      │  Test Runner   │
                      │    (pytest)    │
                      └───────┬────────┘
                              │
               ┌──────────────┼──────────────┐
               │              │              │
       ┌───────▼──────┐  ┌────▼─────┐ ┌──────▼──────┐
       │    CI/CD     │  │   HTML   │ │ Regression  │
       │     Gate     │  │  Report  │ │  Tracker    │
       └──────────────┘  └──────────┘ └─────────────┘
```
Usage Examples
Data Validation Tests
```python
import pytest
from ml_testing.core import DataValidator

@pytest.fixture
def validator():
    return DataValidator.from_config("config.example.yaml")

def test_schema_compliance(validator, training_data):
    """Verify all required columns exist with correct types."""
    result = validator.check_schema(training_data)
    assert result.passed, f"Schema violations: {result.violations}"

def test_null_rates(validator, training_data):
    """No feature column should have more than 5% nulls."""
    assert validator.check_null_rates(training_data, max_rate=0.05).passed

def test_feature_distribution(validator, training_data, reference_data):
    """Feature distributions shouldn't drift from reference."""
    result = validator.check_distribution_drift(
        current=training_data,
        reference=reference_data,
        method="ks_test",
        p_value_threshold=0.01,
    )
    assert result.passed, f"Drifted features: {result.drifted_columns}"
```
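For intuition about what the `ks_test` drift method is measuring: the two-sample Kolmogorov-Smirnov statistic is the largest gap between the empirical CDFs of the two samples. A standalone pure-Python sketch (not the framework's implementation, which would also compute a p-value):

```python
# Two-sample Kolmogorov-Smirnov statistic: the maximum gap between the
# empirical CDFs of two samples. Illustrative only, not framework code.

def ecdf(sample, x):
    """Fraction of sample values <= x."""
    return sum(v <= x for v in sample) / len(sample)

def ks_statistic(a, b):
    """Max |ECDF_a(x) - ECDF_b(x)| over all observed points."""
    points = sorted(set(a) | set(b))
    return max(abs(ecdf(a, x) - ecdf(b, x)) for x in points)

reference = [0.1, 0.2, 0.3, 0.4, 0.5]
shifted   = [1.1, 1.2, 1.3, 1.4, 1.5]  # fully shifted distribution
identical = list(reference)

print(ks_statistic(reference, shifted))    # 1.0 -> maximal drift
print(ks_statistic(reference, identical))  # 0.0 -> no drift
```

A statistic near 0 means the distributions overlap; near 1 means they barely overlap at all, which is what a drift gate should catch.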
Model Performance Regression Tests
```python
from ml_testing.core import ModelTester

tester = ModelTester()

def test_accuracy_above_threshold(trained_model, test_data):
    """Model accuracy must exceed minimum threshold."""
    metrics = tester.evaluate(trained_model, test_data)
    assert metrics["accuracy"] >= 0.85, f"Accuracy {metrics['accuracy']:.3f} below 0.85"

def test_no_regression_from_baseline(trained_model, baseline_model, test_data):
    """New model must not regress more than 2% from baseline."""
    new_metrics = tester.evaluate(trained_model, test_data)
    baseline_metrics = tester.evaluate(baseline_model, test_data)
    regression = baseline_metrics["f1_macro"] - new_metrics["f1_macro"]
    assert regression <= 0.02, f"F1 regression of {regression:.3f} exceeds 0.02"
```
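Raw metric differences like the one above can be noise on a small test set. One way to hedge against that, as the best practices below suggest, is a bootstrap confidence interval on the difference. A pure-Python sketch with made-up per-example correctness vectors (everything here is illustrative, not framework code):

```python
# Bootstrap confidence interval for an accuracy difference between two models.
# The per-example correctness lists are illustrative stand-ins.
import random

random.seed(42)

# 1 = example predicted correctly, 0 = incorrectly (paired on the same test set)
new_correct      = [1] * 870 + [0] * 130  # 87.0% accuracy
baseline_correct = [1] * 850 + [0] * 150  # 85.0% accuracy

def bootstrap_diff_ci(new, base, n_boot=1000, alpha=0.05):
    """Percentile CI for mean(new) - mean(base) under paired resampling."""
    n = len(new)
    diffs = []
    for _ in range(n_boot):
        idx = [random.randrange(n) for _ in range(n)]
        diffs.append(sum(new[i] for i in idx) / n - sum(base[i] for i in idx) / n)
    diffs.sort()
    lo = diffs[int(alpha / 2 * n_boot)]
    hi = diffs[int((1 - alpha / 2) * n_boot) - 1]
    return lo, hi

lo, hi = bootstrap_diff_ci(new_correct, baseline_correct)
print(f"95% CI for accuracy delta: [{lo:.3f}, {hi:.3f}]")
# An interval that contains 0 would mean the gain is not statistically meaningful.
```

A regression gate can then require `lo > -tolerance` rather than comparing single point estimates.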
Fairness and Bias Checks
```python
from ml_testing.core import FairnessChecker

checker = FairnessChecker(protected_attributes=["gender", "age_group"])

def test_demographic_parity(trained_model, test_data):
    """Positive prediction rates should be similar across groups."""
    result = checker.demographic_parity(model=trained_model, data=test_data, min_ratio=0.8)
    for attr, report in result.items():
        assert report["passed"], f"Demographic parity violation on '{attr}': ratio={report['ratio']:.3f}"

def test_equalized_odds(trained_model, test_data):
    """True positive rates should be similar across groups."""
    result = checker.equalized_odds(trained_model, test_data, max_diff=0.1)
    assert result.passed, f"Equalized odds violations: {result.details}"
```
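The demographic parity ratio behind the 80%-rule check is straightforward to compute by hand: the lowest group's positive-prediction rate divided by the highest group's. A toy sketch with hypothetical predictions (not framework code):

```python
# Demographic parity ratio for one protected attribute, mirroring the
# 0.8 threshold used above. Toy predictions, not framework code.

def positive_rate(preds):
    """Fraction of positive (1) predictions."""
    return sum(preds) / len(preds)

def demographic_parity_ratio(preds_by_group):
    """min group positive rate / max group positive rate (1.0 = perfect parity)."""
    rates = [positive_rate(p) for p in preds_by_group.values()]
    return min(rates) / max(rates)

preds = {
    "group_a": [1, 1, 1, 0, 0],  # 60% positive
    "group_b": [1, 0, 0, 0, 0],  # 20% positive
}
ratio = demographic_parity_ratio(preds)
print(f"ratio = {ratio:.2f}")  # ratio = 0.33 -> fails the 0.8 threshold
```

A ratio of 0.33 means the disadvantaged group receives positive predictions at a third of the advantaged group's rate, well below the 80%-rule bar.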
Configuration Reference
| Parameter | Type | Default | Description |
|---|---|---|---|
| `data_validation.quality.max_null_rate` | float | 0.05 | Maximum allowed null rate per column |
| `model_performance.metrics.accuracy.min_threshold` | float | 0.85 | Minimum accuracy to pass |
| `model_performance.metrics.*.regression_tolerance` | float | 0.02 | Maximum allowed drop from baseline |
| `fairness.metrics.demographic_parity_ratio.min_threshold` | float | 0.8 | 80% rule threshold |
| `regression.comparison_metrics` | list | [accuracy, f1_macro] | Metrics to compare against baseline |
Best Practices
- Run data tests before model tests — If data is wrong, model metrics are meaningless. Gate your pipeline.
- Use statistical significance — A 0.5% accuracy diff on 100 samples isn't meaningful. Use confidence intervals.
- Test on slices, not just aggregates — 95% overall accuracy might hide 60% on a critical subgroup.
- Add invariance tests — Predictions shouldn't change for irrelevant input perturbations (whitespace, synonyms).
- Automate in CI — Run all tests on every PR that touches training code. Block merges on failure.
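The invariance idea above can be made concrete with almost no machinery: feed the model an input and an irrelevantly perturbed copy, and assert the predictions match. A sketch using a trivial stand-in classifier (any real model would replace `toy_sentiment`):

```python
# Invariance test sketch: predictions should not change under an irrelevant
# perturbation such as extra whitespace. toy_sentiment is a stand-in model.

def toy_sentiment(text):
    """Stand-in model: positive iff the normalized text contains 'good'."""
    return "positive" if "good" in text.lower().split() else "negative"

def test_whitespace_invariance():
    base = "this product is good"
    perturbed = "  this   product is good  "
    # The raw, unstripped perturbation goes straight to the model.
    assert toy_sentiment(base) == toy_sentiment(perturbed)

test_whitespace_invariance()
print("whitespace invariance holds")
```

The same pattern extends to synonym swaps or casing changes for invariance tests, and to directional expectation tests where a meaningful perturbation (e.g. adding "terrible") should move the prediction a known way.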
Troubleshooting
| Issue | Cause | Fix |
|---|---|---|
| Tests pass locally, fail in CI | Different data splits or random seeds | Set random_seed in config and pin dataset versions |
| Fairness tests too strict | Threshold set for balanced data on imbalanced dataset | Adjust min_ratio based on your data's natural class distribution |
| Performance test flaky | Small test set causes metric variance | Increase test set size or use bootstrap confidence intervals |
| Data drift test false positives | Seasonal patterns trigger the KS test | Use a reference dataset from the same time period, or increase p_value_threshold |
This is 1 of 11 resources in the ML Engineer Toolkit. Get the complete ML Testing Framework with all files, templates, and documentation for $29.
Or grab the entire ML Engineer Toolkit bundle (11 products) for $149 — save 30%.