DEV Community

Saee Barve

Why Your AI Model's Confidence Score Is Probably Lying (And What To Do About It)

The distribution shift problem that breaks modern AI in production, explained for developers who actually deploy these things.

You trained the model. Metrics looked great. You deployed it. Six months later, something is quietly wrong but your accuracy dashboard looks fine.

What happened?

If you are running a modern AI system at scale, especially one using a Mixture-of-Experts architecture, there is a good chance your model's confidence scores have drifted out of alignment with reality. Not because the model got worse at prediction. Because the calibration broke silently, without error, without warning.

This post explains what that means, why it happens to MoE models specifically, and what you can do about it as a developer.

Quick Vocabulary Check

Before diving in, two terms you need:

Calibration: If your model says "I'm 80% confident," it should be correct 80% of the time it says that. A calibrated model's confidence scores are honest probability estimates. A miscalibrated model's confidence scores cannot be read as probabilities, even when they still rank predictions sensibly.

Distribution shift: The data your model sees in production is not the same as the data it was trained on. The distribution of inputs drifts over time. This is not an edge case; it is the normal state of any deployed model.
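To make the calibration definition concrete, here is a minimal sketch that groups predictions by stated confidence and compares each group with its observed accuracy. The data is made up for illustration, and `accuracy_by_confidence` is a hypothetical helper name, not part of any library.

```python
from collections import defaultdict

def accuracy_by_confidence(confidences, correct):
    """Group predictions by stated confidence; return observed accuracy per group."""
    groups = defaultdict(list)
    for conf, is_correct in zip(confidences, correct):
        groups[conf].append(is_correct)
    return {conf: sum(hits) / len(hits) for conf, hits in groups.items()}

# Hypothetical toy data: stated confidence, and whether the prediction was right (1) or wrong (0).
confidences = [0.8, 0.8, 0.8, 0.8, 0.8, 0.6, 0.6, 0.6, 0.6, 0.6]
correct     = [1,   1,   1,   1,   0,   1,   1,   0,   0,   0]

observed = accuracy_by_confidence(confidences, correct)
for conf, acc in sorted(observed.items()):
    print(f"stated {conf:.1f} -> observed {acc:.1f} (gap {abs(conf - acc):.1f})")
```

In this toy data the 0.8 group is calibrated (4 of 5 correct), while the 0.6 group is overconfident (only 2 of 5 correct).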

The Architecture: Mixture-of-Experts (MoE)

Many large-scale AI models today use MoE. The idea is simple:

Instead of one giant network, you have many specialized sub-networks called experts
A router looks at each input and decides which expert(s) handle it
This lets you scale model capacity without scaling compute linearly

Two flavors of routing:

Hard Routing:  input → router → ONE expert → output
Soft Routing:  input → router → weighted blend of MULTIPLE experts → output


Soft routing is more expressive. It is also where calibration gets complicated.
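The two routing flavors can be sketched in a few lines of plain Python. This is an illustration under assumptions, not any particular framework's router: it assumes the router emits one score per expert, that soft routing uses a softmax over those scores, and that hard routing picks the top-scoring expert. All numbers are hypothetical.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of router scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def hard_route(router_scores, expert_preds):
    """Hard routing: the single highest-scoring expert produces the output."""
    best = max(range(len(router_scores)), key=lambda k: router_scores[k])
    return expert_preds[best]

def soft_route(router_scores, expert_preds):
    """Soft routing: the output is a routing-weighted blend of all experts."""
    weights = softmax(router_scores)
    return sum(w * p for w, p in zip(weights, expert_preds))

router_scores = [2.0, 1.0, 0.0]  # hypothetical router scores for one input
expert_preds = [0.9, 0.6, 0.3]   # hypothetical per-expert probabilities

hard = hard_route(router_scores, expert_preds)
soft = soft_route(router_scores, expert_preds)
print(f"hard routing: {hard:.3f}, soft routing: {soft:.3f}")
```

The hard-routed output is exactly one expert's prediction, while the soft-routed output depends on every expert's prediction and on the exact routing weights.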

The Problem: Perfectly Calibrated Experts, Broken Aggregate

Here is the scenario that should concern every ML engineer.

Suppose every expert in your MoE is individually well calibrated. When Expert A says 0.8, it is right 80% of the time. Same for Expert B, Expert C, all of them.

You might assume the combined model is also well-calibrated.

Under distribution shift, that assumption fails.

Here is why.

With soft routing, your final prediction is:

f(x) = r1(x) * f1(x) + r2(x) * f2(x) + ... + rK(x) * fK(x)

Here r1, r2, ..., rK are the routing weights (which sum to 1) and f1, f2, ..., fK are the expert predictions.

The same final score (say, 0.75) can come from completely different configurations:


Config A: r1=0.9, f1=0.75, r2=0.1, f2=0.75  → f(x) = 0.75
Config B: r1=0.5, f1=0.9,  r2=0.5, f2=0.6   → f(x) = 0.75
Config C: r1=0.3, f1=0.4,  r2=0.7, f2=0.9   → f(x) = 0.75

On your training distribution, these configurations fire in certain proportions. Those proportions make the calibration work out — the deviations cancel, and 0.75 ends up being right 75% of the time.

Then distribution shift happens.

New data changes how often different types of inputs appear. Different routing configurations fire at different rates. The proportions that made calibration balance out no longer hold.

Now when the model says 0.75, maybe it is only right 58% of the time. Or 91% of the time. The confidence score has become unreliable — and you have no easy way to know from the outside.
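The arithmetic behind this can be sketched directly. Assume, purely for illustration, that three routing configurations all output 0.75 but that the true rate of being correct differs by configuration. The rates and mixing proportions below are invented numbers, not measurements.

```python
def accuracy_at_score(mix, true_rates):
    """Accuracy among inputs sharing one score, given how often each configuration fires."""
    if abs(sum(mix.values()) - 1.0) > 1e-9:
        raise ValueError("configuration proportions must sum to 1")
    return sum(mix[name] * true_rates[name] for name in mix)

# Hypothetical: every configuration outputs 0.75, but the true correctness rate differs.
true_rates = {"A": 0.75, "B": 0.60, "C": 0.90}

# Hypothetical proportions in which the configurations fire.
train_mix = {"A": 0.50, "B": 0.25, "C": 0.25}
shifted_mix = {"A": 0.20, "B": 0.70, "C": 0.10}

train_acc = accuracy_at_score(train_mix, true_rates)
shifted_acc = accuracy_at_score(shifted_mix, true_rates)
print(f"score 0.75 on training mix: right {train_acc:.2f} of the time")
print(f"score 0.75 on shifted mix:  right {shifted_acc:.2f} of the time")
```

Nothing about the experts or the router changed between the two lines of output. Only the proportions moved, and the same score of 0.75 went from calibrated (0.75) to overconfident (0.66).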

Why Hard Routing Does Not Have This Problem

With hard routing, each input goes to exactly one expert. Your aggregate prediction is just that expert's prediction. The full routing information collapses to a simple pair: (which expert, what confidence).

If Expert 2 says 0.75, and Expert 2 is calibrated, then 0.75 is trustworthy regardless of whether the test distribution sends more or fewer inputs to Expert 2 than the training distribution did.

Hard routing is more robust to distribution shift in this specific dimension. The tradeoff is expressiveness: hard routing cannot capture cases where multiple experts' knowledge genuinely needs to be blended.

How Bad Can It Get?

The failure is worst on inputs that trigger the fragile configurations, specifically the cases where:

Multiple experts receive substantial routing weight (not dominated by one expert)
Those experts disagree significantly in their predictions
The aggregate prediction therefore depends heavily on the exact routing weights

These are the cases where a mild shift in data distribution (one that does not change what the right answer is, does not change expert behavior, and only changes how often certain input types appear) can flip the calibration from reliable to useless.

And these are exactly the kinds of inputs where you most need reliable uncertainty estimates. If experts agree, you already have a signal. When experts disagree and you need the aggregate to guide you, that is when the calibration tends to be least trustworthy.
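One way to look for such inputs is to score each one by how much the experts disagree around the blend, weighted by routing weight. The sketch below is an assumption about how you might operationalize the three conditions above, not a metric taken from a specific library or paper. `fragility_score` and `THRESHOLD` are hypothetical names and values.

```python
def fragility_score(routing_weights, expert_preds):
    """Routing-weighted squared spread of expert predictions around the blend.

    Zero when one expert holds all the weight or when all experts agree;
    larger when weight is split across experts that disagree.
    """
    blend = sum(r * f for r, f in zip(routing_weights, expert_preds))
    return sum(r * (f - blend) ** 2 for r, f in zip(routing_weights, expert_preds))

# Hypothetical (routing weights, expert predictions) for three inputs.
examples = {
    "one expert dominates, experts disagree": ([1.0, 0.0], [0.9, 0.2]),
    "weight is split, experts agree": ([0.5, 0.5], [0.75, 0.75]),
    "weight is split, experts disagree": ([0.5, 0.5], [0.9, 0.6]),
}

THRESHOLD = 0.01  # hypothetical cutoff; tune on your own data

scores = {name: fragility_score(r, f) for name, (r, f) in examples.items()}
for name, score in scores.items():
    flag = "FRAGILE" if score > THRESHOLD else "ok"
    print(f"{name}: {score:.4f} {flag}")
```

Only the third case is flagged, since it is the only one where both conditions hold at once: split routing weight and disagreeing experts.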

The Fix: Adversarial Reweighting During Training

The solution is to train the model to be calibrated not just on the average training distribution, but on stressed versions of that distribution.

The key insight: examples where the model has high loss are a proxy for the fragile configurations. These are the examples where routing weights create a shaky balance. If you train against adversarially reweighted distributions that emphasize high-loss examples, you make the model more robust where it needs to be.

In practice, this means using an exponential tilt during training:

# Conceptual implementation of the Robust MoE training objective
import torch
import torch.nn.functional as F

def robust_moe_loss(losses, eta=1.0):
    """
    losses: per-example losses in the minibatch, shape (batch,)
    eta: tilt strength (higher = more emphasis on hard examples)
    """
    # Exponential-tilt weights. softmax(eta * losses) equals
    # exp(eta * losses) / sum(exp(eta * losses)) but avoids overflow.
    weights = torch.softmax(eta * losses, dim=0)

    # Weighted loss emphasizes high-loss (fragile) examples
    robust_loss = (weights * losses).sum()

    return robust_loss

# Standard training loop modification
# (model, dataloader, and optimizer are assumed to be defined elsewhere)
for batch_x, batch_y in dataloader:
    predictions = model(batch_x)

    # Per-example losses. This assumes a classifier that outputs logits;
    # reduction='none' keeps one loss value per example.
    per_example_losses = F.cross_entropy(predictions, batch_y, reduction='none')

    # Standard ERM loss
    # erm_loss = per_example_losses.mean()

    # Robust MoE loss - upweights hard examples
    loss = robust_moe_loss(per_example_losses, eta=0.5)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
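To get a feel for what the tilt strength eta does, here is a small pure-Python sketch of the same weighting on three made-up losses. `tilt_weights` is a hypothetical helper that mirrors the normalization used in the training objective above. The loss values are illustrative only.

```python
import math

def tilt_weights(losses, eta):
    """Normalized exponential-tilt weights: w_i is proportional to exp(eta * loss_i)."""
    m = max(losses)
    # Subtracting the max before exponentiating avoids overflow and leaves the ratios unchanged.
    exps = [math.exp(eta * (loss - m)) for loss in losses]
    total = sum(exps)
    return [e / total for e in exps]

losses = [0.1, 0.5, 2.0]  # hypothetical per-example losses
tilted = {}
for eta in (0.0, 0.5, 2.0):
    weights = tilt_weights(losses, eta)
    tilted[eta] = sum(w * loss for w, loss in zip(weights, losses))
    shown = ", ".join(f"{w:.3f}" for w in weights)
    print(f"eta={eta}: weights=[{shown}] tilted loss={tilted[eta]:.3f}")
```

With eta=0 the weights are uniform and the objective reduces to the ordinary mean loss. As eta grows, weight concentrates on the highest-loss example and the objective approaches the worst-case loss in the batch.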

There is also a more targeted variant called Robust Filtered, which only applies the reweighting to routing-relevant examples — specifically:

Examples where the blended prediction is worse than the best individual expert
Examples where experts substantially disagree around the aggregate prediction

import torch

def robust_filtered_loss(losses, expert_losses, expert_predictions, eta=1.0):
    """
    Apply robust reweighting only to routing-relevant examples.

    losses: per-example loss of the blended prediction, shape (batch,)
    expert_losses: per-example loss of each expert, shape (batch, K).
        Assumed to be computed by the caller; needed to compare the
        blend against the best single expert.
    expert_predictions: per-expert predictions, shape (batch, K)
    """
    # Find examples where the blend is worse than the best single expert
    best_expert_loss = expert_losses.min(dim=1).values
    blend_worse = losses > best_expert_loss

    # Find examples where experts disagree substantially
    # (above-median variance is a simplified stand-in for "substantially")
    expert_variance = expert_predictions.var(dim=1)
    high_disagreement = expert_variance > expert_variance.median()

    # Routing-relevant subset
    routing_relevant = blend_worse | high_disagreement

    # ERM on full batch
    erm_loss = losses.mean()

    # Robust reweighting on routing-relevant subset
    if routing_relevant.any():
        subset_losses = losses[routing_relevant]
        weights = torch.softmax(eta * subset_losses, dim=0)
        robust_term = (weights * subset_losses).sum()
    else:
        robust_term = losses.new_zeros(())

    return erm_loss + robust_term

Both approaches consistently improve the calibration-accuracy tradeoff under distribution shift without a meaningful accuracy cost.

What To Do Right Now as a Developer

You might not be retraining your model today. Here is what you can do immediately:

  1. Add calibration monitoring to your eval pipeline
import numpy as np

def expected_calibration_error(y_true, y_prob, n_bins=10):
    """
    Compute Expected Calibration Error (ECE).
    Lower is better. 0 = perfect calibration.
    """
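    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    ece = 0.0

    for i in range(n_bins):
        lower, upper = bin_boundaries[i], bin_boundaries[i+1]
        # Include the upper edge in the last bin so y_prob == 1.0 is counted
        if i == n_bins - 1:
            mask = (y_prob >= lower) & (y_prob <= upper)
        else:
            mask = (y_prob >= lower) & (y_prob < upper)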

        if mask.sum() == 0:
            continue

        bin_accuracy = y_true[mask].mean()
        bin_confidence = y_prob[mask].mean()
        bin_size = mask.sum()

        ece += (bin_size / len(y_true)) * abs(bin_accuracy - bin_confidence)

    return ece

# Add to your regular eval run
ece = expected_calibration_error(y_true, model_probabilities)
print(f"ECE: {ece:.4f}")  # flag if this creeps up over time
  2. Plot reliability diagrams regularly
import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve

def plot_reliability_diagram(y_true, y_prob, title="Reliability Diagram"):
    fraction_of_positives, mean_predicted_value = calibration_curve(
        y_true, y_prob, n_bins=10
    )

    plt.figure(figsize=(8, 6))
    plt.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')
    plt.plot(mean_predicted_value, fraction_of_positives, 
             's-', label='Model')
    plt.xlabel('Mean predicted probability')
    plt.ylabel('Fraction of positives')
    plt.title(title)
    plt.legend()
    plt.show()

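A model drifting toward overconfidence will show a curve that falls below the diagonal in the high-probability bins: the predicted probability is higher than the observed fraction of positives. Catch this early.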

  3. Track input distribution drift
from scipy.stats import ks_2samp

def detect_distribution_shift(train_features, current_features, threshold=0.05):
    """
    Kolmogorov-Smirnov test for distribution shift per feature.
    Flag features where p-value < threshold.
    """
    shifted_features = []

    for i in range(train_features.shape[1]):
        stat, p_value = ks_2samp(train_features[:, i], current_features[:, i])
        if p_value < threshold:
            shifted_features.append({
                'feature_index': i,
                'ks_statistic': stat,
                'p_value': p_value
            })

    return shifted_features
  4. Use temperature scaling as a quick post-hoc fix

If you cannot retrain, temperature scaling is the fastest way to recalibrate a model after deployment:

import torch
import torch.nn as nn

class TemperatureScaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, logits):
        return logits / self.temperature

    def fit(self, logits, labels, lr=0.01, max_iter=50):
        optimizer = torch.optim.LBFGS([self.temperature], lr=lr, max_iter=max_iter)
        criterion = nn.CrossEntropyLoss()

        def eval_step():
            optimizer.zero_grad()
            loss = criterion(self.forward(logits), labels)
            loss.backward()
            return loss

        optimizer.step(eval_step)
        return self

Note: temperature scaling helps on average but does not address the subset-specific calibration failures from distribution shift. It is a patch, not a solution.
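A small pure-Python sketch shows why temperature scaling is only a global knob. It divides every logit by the same temperature, which softens or sharpens confidence everywhere at once but never changes which class is predicted. The logits below are hypothetical, and `softmax_with_temperature` is an illustrative helper, not a library function.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [4.0, 1.0, 0.0]  # hypothetical logits for one input
results = {}
for temperature in (1.0, 2.0):
    probs = softmax_with_temperature(logits, temperature)
    results[temperature] = probs
    top = max(probs)
    print(f"T={temperature}: top confidence {top:.3f}, predicted class {probs.index(top)}")
```

Raising the temperature lowers the top confidence while the predicted class stays the same. Because one temperature is applied to every input, it cannot lower confidence on the fragile inputs without also lowering it on the inputs that were already well calibrated.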

Summary

| Routing Type | Calibration Under Shift | Why |
| --- | --- | --- |
| Hard routing | Robust ✅ | Calibration depends only on the (expert, confidence) pair |
| Soft routing | Fragile ⚠️ | Different configurations collapse to the same score; shift changes their balance |

The fix: Train with adversarial reweighting (Robust MoE or Robust Filtered) to stress the model on its hardest examples. At minimum, monitor ECE and distribution shift in production.

The deeper lesson: calibration is a system-level property. Calibrated parts do not automatically combine into a calibrated whole — especially when distribution shift changes how those parts interact.

Have you dealt with calibration drift in production? What monitoring setup worked for you? Drop it in the comments.
