DEV Community

felipe muniz


How to Measure Whether Your Model's Uncertainty Space Is Flat or Curved

A practical guide to Riemannian epistemic geometry in language models, with code.


Most calibration research treats uncertainty as a scalar or a vector: you compute a confidence score, compare it to ground truth, and minimize expected calibration error (ECE). The space in which that uncertainty lives is assumed to be flat.

That assumption might be wrong. And if it is wrong, it has concrete consequences for out-of-distribution detection, adversarial robustness, and AI safety.

This post explains how to test it, using code from my current research on AletheionLLM-v2.


The baseline: diagonal distance in a 5D epistemic manifold

AletheionLLM-v2 is a 354M-parameter decoder-only LLM with an integrated epistemic architecture called ATIC. Instead of producing a single confidence score, the model maintains a 5-dimensional manifold in which each axis represents a distinct component of uncertainty, learned via BayesianTau.

The current distance metric (branch main) is diagonal:

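import numpy as np

def distance_diagonal(x1, x2, tau_sq):
    # Weighted Euclidean distance: each axis scaled by its learned variance tau^2
    diff = x1 - x2
    tau_sq_safe = np.maximum(tau_sq, 1e-8)  # guard against division by zero
    return np.sqrt(np.sum(diff**2 / tau_sq_safe))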

Each axis has its own learned variance. The axes are independent. The space is R^5, rescaled.
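Because the diagonal distance only divides each squared difference by tau^2, it is the same as an ordinary Euclidean distance after rescaling each axis by 1/tau, or equivalently a quadratic form with the constant metric G = diag(1/tau^2). The NumPy sketch below checks that numerically; the tau_sq values are made up for illustration, not learned ones.

```python
import numpy as np

def distance_diagonal(x1, x2, tau_sq):
    # Baseline formula: squared differences divided by per-axis variance
    diff = x1 - x2
    return np.sqrt(np.sum(diff**2 / np.maximum(tau_sq, 1e-8)))

tau_sq = np.array([0.5, 1.0, 2.0, 0.1, 4.0])   # illustrative, not learned values
rng = np.random.default_rng(0)
x1, x2 = rng.random(5), rng.random(5)

# 1) Euclidean distance after rescaling every axis by 1/tau
scale = 1.0 / np.sqrt(tau_sq)
d_eucl = np.linalg.norm(x1 * scale - x2 * scale)

# 2) Quadratic form with the constant metric G = diag(1/tau^2)
G = np.diag(1.0 / tau_sq)
d_quad = np.sqrt((x1 - x2) @ G @ (x1 - x2))

d_diag = distance_diagonal(x1, x2, tau_sq)
print(d_diag, d_eucl, d_quad)   # identical up to floating-point rounding
```

All three numbers describe the same flat geometry; only the coordinates differ.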

This already works well: ECE 0.0176 and Brier score 0.1528, best-in-class on out-of-distribution (OOD) WikiText-103, outperforming GPT-2 Medium and OPT-350M on epistemic calibration.

But there is a question the diagonal cannot answer: does the epistemic space have curvature?


Why curvature is a different question from correlation

Before going further, one distinction matters.

A full Mahalanobis metric, where G is a constant 5x5 matrix parameterized through its Cholesky factor, captures correlations between epistemic dimensions. That is useful. But it does not produce curvature.

If G is constant, every partial derivative of its components is zero, so the Christoffel symbols all vanish:

$$\Gamma^k_{ij} = \frac{1}{2}\, g^{kl}\left(\partial_i g_{jl} + \partial_j g_{il} - \partial_l g_{ij}\right) = 0$$

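Vanishing Christoffel symbols imply zero Riemann curvature, because the Riemann tensor is built entirely from the Christoffel symbols and their derivatives. The space is still flat, just with oblique coordinates. Geodesics are still straight lines.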

For real curvature, G must vary with position. G(x) must be a tensor field, not a constant matrix.
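A quick numerical check of the Christoffel-symbol argument: the sketch below estimates Gamma^k_ij by central differences of an arbitrary metric field. It works in two dimensions for readability (the function itself does not depend on the dimension), uses a constant oblique matrix as a stand-in for full_mahalanobis, and uses a hypothetical conformal field G(x) = (1 + |x|^2) I as a stand-in for a learned G(x).

```python
import numpy as np

def christoffel(metric_fn, x, h=1e-5):
    """Gamma[k, i, j] = Gamma^k_ij of a metric field at point x.

    metric_fn maps a point of shape [n] to an SPD matrix of shape [n, n];
    derivatives of the metric are estimated with central differences.
    """
    n = x.shape[0]
    dG = np.zeros((n, n, n))              # dG[l, i, j] = d g_ij / d x^l
    for l in range(n):
        e = np.zeros(n)
        e[l] = h
        dG[l] = (metric_fn(x + e) - metric_fn(x - e)) / (2 * h)
    A = dG                                # A[i, j, l] = d_i g_jl
    B = dG.transpose(1, 0, 2)             # B[i, j, l] = d_j g_il
    C = dG.transpose(1, 2, 0)             # C[i, j, l] = d_l g_ij
    G_inv = np.linalg.inv(metric_fn(x))
    return 0.5 * np.einsum("kl,ijl->kij", G_inv, A + B - C)

# Flat case: constant oblique metric (full_mahalanobis style)
G_const = np.array([[2.0, 0.5], [0.5, 1.0]])
gamma_const = christoffel(lambda x: G_const, np.array([0.3, -0.2]))

# Position-dependent case: hypothetical conformal field G(x) = (1 + |x|^2) I
gamma_field = christoffel(lambda x: (1.0 + x @ x) * np.eye(2), np.array([1.0, 0.0]))

print(np.abs(gamma_const).max())   # 0.0: a constant G has no Christoffel symbols
print(gamma_field[0, 0, 0])        # about 0.5
```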


Branch real_geodesic: making the metric a field

In the real_geodesic branch, a lightweight network (5 -> 32 -> 15, roughly 700 parameters: 5·32 + 32 + 32·15 + 15 = 687) produces a position-dependent symmetric positive-definite (SPD) tensor at every point in the manifold:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MetricNet(nn.Module):
    def __init__(self, dim=5, hidden_dim=32):
        super().__init__()
        self.dim = dim
        self.n_chol = dim * (dim + 1) // 2  # 15 for dim=5

        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.Tanh(),  # Tanh, not ReLU -- G(x) must be smooth (C1)
            nn.Linear(hidden_dim, self.n_chol),
        )

        # Zero init on last layer -> raw = 0 everywhere -> G(x) ~ 0.48 * I at start
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

        # Pre-computed indices for lower triangular construction
        tril_idx = torch.tril_indices(dim, dim)
        self.register_buffer("tril_row", tril_idx[0])
        self.register_buffer("tril_col", tril_idx[1])
        self.register_buffer("diag_idx", torch.arange(dim))

    def forward(self, coords):
        """coords: [..., 5] -> G: [..., 5, 5] SPD"""
        raw = self.net(coords)  # [..., 15]
        batch_shape = raw.shape[:-1]

        L = torch.zeros(*batch_shape, self.dim, self.dim,
                         device=raw.device, dtype=raw.dtype)
        L[..., self.tril_row, self.tril_col] = raw

        # Positive diagonal via softplus + offset (not exp -- more stable)
        L[..., self.diag_idx, self.diag_idx] = (
            F.softplus(L[..., self.diag_idx, self.diag_idx]) + 1e-3
        )

        return torch.matmul(L, L.transpose(-1, -2))  # SPD guaranteed

Key design choices:

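  • Tanh activation instead of ReLU. G(x) is a metric field, so it must be smooth. ReLU creates kinks where the derivatives of G jump, which would leave the Christoffel symbols undefined at those points.
  • softplus + 1e-3 on the diagonal instead of exp. More numerically stable during training and avoids gradient explosion; the 1e-3 offset keeps every diagonal entry of L strictly positive, so G = L L^T is always positive definite.
  • Zero init on last layer. At initialization, the network outputs zeros for all inputs, so every diagonal entry of L is softplus(0) + 1e-3 ≈ 0.694 and G(x) starts as approximately 0.48 * I everywhere (0.694^2 ≈ 0.48). Training starts stable.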
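The 0.48 figure and the SPD guarantee can be checked without PyTorch. This NumPy sketch mirrors MetricNet's Cholesky head for a single point; chol_params_to_spd is a hypothetical helper for illustration, not part of the repository.

```python
import numpy as np

def softplus(z):
    # Numerically stable log(1 + exp(z))
    return np.logaddexp(0.0, z)

def chol_params_to_spd(raw, dim=5):
    """Hypothetical NumPy mirror of MetricNet's head: 15 raw values -> 5x5 SPD matrix."""
    rows, cols = np.tril_indices(dim)
    L = np.zeros((dim, dim))
    L[rows, cols] = raw
    diag = np.arange(dim)
    L[diag, diag] = softplus(L[diag, diag]) + 1e-3   # strictly positive diagonal
    return L @ L.T

# Zero-initialized last layer: every raw Cholesky entry is 0
G0 = chol_params_to_spd(np.zeros(15))
print(np.round(np.diag(G0), 4))   # each entry is (log 2 + 1e-3)^2, about 0.4818

# Arbitrary raw outputs still give a positive-definite metric
rng = np.random.default_rng(0)
G_rand = chol_params_to_spd(rng.normal(size=15))
print(np.linalg.eigvalsh(G_rand).min())   # positive
```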

Distance between two epistemic states is a line integral computed via Gauss-Legendre quadrature:

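def line_integral_distance(self, p, q):
    """p: [B, T, 5], q: [5] or [B, T, 5] -> distance: [B, T, 1]

    Method of MetricNet. Assumes self.gl_points and self.gl_weights hold
    self.n_quad Gauss-Legendre nodes and weights mapped to [0, 1].
    """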
    if q.dim() == 1:
        q = q.unsqueeze(0).unsqueeze(0).expand_as(p)

    delta = q - p
    total = torch.zeros(p.shape[0], p.shape[1], 1,
                         device=p.device, dtype=p.dtype)

    for i in range(self.n_quad):
        t = self.gl_points[i]
        w = self.gl_weights[i]

        x_t = p + t * delta           # point along straight line
        G_t = self.forward(x_t)       # G(x) at that point
        Gd = torch.matmul(delta.unsqueeze(-2), G_t).squeeze(-2)
        integrand = (Gd * delta).sum(dim=-1, keepdim=True)
        total = total + w * torch.sqrt(integrand.clamp(min=1e-8))

    return total

One clarification worth making explicit: this computes the length of the straight line between p and q under the varying metric, not the true geodesic. The geodesic minimizes path length, so its length is never greater than the straight-line length and is generally shorter. Computing the true geodesic requires a shooting method or an ODE solver. The straight-line approximation is differentiable, cheap (one MetricNet evaluation per quadrature node, so five per distance by default), and sufficient to detect whether G(x) varies along the path, which is the primary question.
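To see that the straight-line length is only an upper bound on geodesic distance, one can compare it against a family of bent paths between the same endpoints. This sketch assumes a hypothetical toy metric G(x) = (1 + |x|^2) I in two dimensions and bends the path with a quadratic bump toward the origin.

```python
import numpy as np

def metric(x):
    # Hypothetical toy field: conformal factor growing away from the origin
    return (1.0 + x @ x) * np.eye(2)

def curve_length(p, q, bend, normal, n_quad=16):
    """Length under `metric` of gamma(t) = p + t (q - p) + bend * t (1 - t) * normal."""
    nodes, weights = np.polynomial.legendre.leggauss(n_quad)
    t_nodes, weights = (nodes + 1) / 2, weights / 2   # map [-1, 1] -> [0, 1]
    total = 0.0
    for t, w in zip(t_nodes, weights):
        x_t = p + t * (q - p) + bend * t * (1 - t) * normal
        v_t = (q - p) + bend * (1 - 2 * t) * normal    # gamma'(t)
        total += w * np.sqrt(v_t @ metric(x_t) @ v_t)
    return total

p, q = np.array([-1.0, 0.5]), np.array([1.0, 0.5])
normal = np.array([0.0, -1.0])                        # bend toward the origin

straight = curve_length(p, q, 0.0, normal)
best = min(curve_length(p, q, b, normal) for b in np.linspace(0.0, 1.0, 101))
print(f"straight line: {straight:.4f}, shortest bent curve: {best:.4f}")
```

Any bent curve that beats the straight line shows the straight line is not the geodesic, and the geodesic distance is at most the shortest value found here.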

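When G depends on position, the Christoffel symbols are in general no longer zero, and geodesics appear as curves in these coordinates. Position dependence is necessary for intrinsic curvature but not sufficient: the flat plane written in polar coordinates has the position-dependent metric diag(1, r^2) and zero curvature. Whether the learned space is actually curved is decided by the Riemann tensor, not by G(x) varying alone.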
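A position-dependent metric does not by itself imply curvature, and this is easy to check numerically in 2-D. For a metric with no cross term, ds^2 = E du^2 + G dv^2, the Gaussian curvature has a closed form in E, G and their derivatives. The sketch below evaluates it with finite differences for the flat plane in polar coordinates, where the metric varies with r, and for the unit sphere, which is genuinely curved. It is a 2-D illustration only; for the 5-D manifold the full Riemann tensor would be needed.

```python
import numpy as np

def gaussian_curvature(E, G, u, v, h=1e-4):
    """Gaussian curvature of ds^2 = E du^2 + G dv^2:
    K = -1/(2 sqrt(EG)) * [d_u(G_u / sqrt(EG)) + d_v(E_v / sqrt(EG))],
    with every derivative taken by central differences."""
    root = lambda a, b: np.sqrt(E(a, b) * G(a, b))
    gu = lambda a, b: (G(a + h, b) - G(a - h, b)) / (2 * h) / root(a, b)
    ev = lambda a, b: (E(a, b + h) - E(a, b - h)) / (2 * h) / root(a, b)
    d_u = (gu(u + h, v) - gu(u - h, v)) / (2 * h)
    d_v = (ev(u, v + h) - ev(u, v - h)) / (2 * h)
    return -(d_u + d_v) / (2 * root(u, v))

# Flat plane in polar coordinates (u = r, v = theta): G = r^2 varies with position
K_polar = gaussian_curvature(lambda r, t: 1.0, lambda r, t: r**2, 2.0, 0.7)

# Unit sphere (u = polar angle, v = azimuth): genuinely curved
K_sphere = gaussian_curvature(lambda a, b: 1.0, lambda a, b: np.sin(a)**2, 1.0, 0.3)

print(f"polar plane K = {K_polar:.2e}, unit sphere K = {K_sphere:.4f}")
```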


The experiment: three branches, one falsifiable question

| Branch | Metric | Geometry |
|---|---|---|
| main | G = diag(1/tau^2) | Flat, orthogonal axes |
| full_mahalanobis | G = constant 5x5 | Flat, oblique axes |
| real_geodesic | G(x) = learned field | Potentially curved |

The test uses three categories of input pairs:

probes = {
    "high_confidence": [
        ("The capital of France is", "Paris"),
        ("2 + 2 =", "4"),
    ],
    "low_confidence": [
        ("The exact number of neurons in the human brain is", "86"),
    ],
    "context_sensitive": [
        ("The bank was steep and", "muddy"),    # bank = riverbank
        ("The bank was closed and", "dark"),    # bank = institution
        ("He left the plant near", "water"),    # plant = vegetation
        ("He left the plant near", "the door"), # plant = factory
    ]
}

The context-sensitive pairs are the key: same surface token, different semantic region of the manifold. If G(x) has learned real structure, the line-integral distance between "bank = riverbank" and "bank = institution" should be larger than the distance between two within-domain contexts, even though the diagonal distance would treat them similarly.
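Before interpreting these comparisons, a useful sanity check is that the real_geodesic distance collapses to the main-branch distance when the metric field is constant and equal to diag(1/tau^2). The NumPy sketch below is a simplified stand-in for the PyTorch line_integral_distance method; metric_fn, the tau_sq values and the endpoints are illustrative assumptions.

```python
import numpy as np

def gauss_legendre_01(n):
    # Gauss-Legendre nodes/weights mapped from [-1, 1] to [0, 1]; weights sum to 1
    nodes, weights = np.polynomial.legendre.leggauss(n)
    return (nodes + 1) / 2, weights / 2

def line_integral_distance_np(p, q, metric_fn, n_quad=5):
    """Straight-line length from p to q under a metric field (hypothetical helper)."""
    t_nodes, weights = gauss_legendre_01(n_quad)
    delta = q - p
    return sum(w * np.sqrt(max(delta @ metric_fn(p + t * delta) @ delta, 1e-8))
               for t, w in zip(t_nodes, weights))

def distance_diagonal(x1, x2, tau_sq):
    diff = x1 - x2
    return np.sqrt(np.sum(diff**2 / np.maximum(tau_sq, 1e-8)))

tau_sq = np.array([0.5, 1.0, 2.0, 0.1, 4.0])   # illustrative, not learned values

def constant_metric(x):
    # The main-branch metric written as a (trivial) field
    return np.diag(1.0 / tau_sq)

p, q = np.full(5, 0.2), np.full(5, 0.7)
d_line = line_integral_distance_np(p, q, constant_metric)
d_diag = distance_diagonal(p, q, tau_sq)
print(d_line, d_diag)   # agree up to rounding
```

If the two numbers disagree on a constant field, the quadrature setup (node mapping or weight scaling) is wrong, and any branch comparison on the probes would be meaningless.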


Detecting metric variation along a path

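import numpy as np
import torch


def measure_metric_variation(metric_net, x_start, x_end, n_samples=20, device="cpu"):
    """Sample G(x) along the straight path x_start -> x_end (NumPy arrays of shape [5])
    and report how much each metric entry varies."""
    G_samples = []

    with torch.no_grad():  # inspection only; also required for .numpy() below
        for t in np.linspace(0, 1, n_samples):
            x_t = x_start + t * (x_end - x_start)
            x_tensor = torch.tensor(x_t, dtype=torch.float32, device=device)
            G_t = metric_net(x_tensor.unsqueeze(0).unsqueeze(0))  # [1, 1, 5, 5]
            G_samples.append(G_t[0, 0].cpu().numpy())

    G_stack = np.stack(G_samples)          # [n_samples, 5, 5]
    variation = np.std(G_stack, axis=0)    # per-entry std along the path

    print(f"Mean metric variation: {variation.mean():.6f}")
    print(f"Max element variation: {variation.max():.6f}")
    # A varying metric is necessary, not sufficient, for curvature
    print(f"Verdict: {'CURVED' if variation.max() > 0.01 else 'FLAT'}")

    return variation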

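If G varies along the path from a high-confidence state to a low-confidence state, the learned metric is not constant and the manifold has non-trivial local geometry. That is a necessary condition for curvature; confirming curvature itself means checking the Riemann tensor. If G converges to a constant, the space is flat, and the diagonal (or, if off-diagonal terms survive, the full Mahalanobis metric) was correct for a fundamental reason, not as an approximation.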
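The 0.01 threshold in measure_metric_variation is absolute, so its meaning depends on the overall scale of G, which starts near 0.48 with the zero-initialized head. One option, not part of the original code, is to normalize the spread by the typical size of G along the path. A NumPy sketch with two hypothetical metric fields:

```python
import numpy as np

def relative_metric_variation(metric_fn, x_start, x_end, n_samples=20):
    """Largest per-entry std of G along the straight path, divided by the
    mean Frobenius norm of G on that path (a scale-free ratio)."""
    ts = np.linspace(0.0, 1.0, n_samples)
    G_stack = np.stack([metric_fn(x_start + t * (x_end - x_start)) for t in ts])
    spread = G_stack.std(axis=0).max()
    scale = np.linalg.norm(G_stack, axis=(1, 2)).mean()
    return spread / scale

# Two hypothetical metric fields for illustration
def constant_field(x):
    return 0.48 * np.eye(5)

def varying_field(x):
    return (0.48 + 0.2 * np.tanh(x.sum())) * np.eye(5)

x0, x1 = np.zeros(5), np.ones(5)
r_const = relative_metric_variation(constant_field, x0, x1)
r_vary = relative_metric_variation(varying_field, x0, x1)
print(r_const, r_vary)   # ~0 for the constant field, clearly positive for the other
```

Multiplying either field by a constant leaves the ratio unchanged, which is the point of normalizing.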


What each result means

If real_geodesic learns an approximately constant G(x):

The epistemic manifold of a 354M-parameter LLM is intrinsically flat. The diagonal metric was not a lazy approximation. It was geometrically correct. ECE 0.0176 reflects genuine calibration, not a subspace artifact.

If G(x) learns structural variation:

There are regions of the manifold with distinct geometry. Two epistemic states that appear equidistant in diagonal coordinates may have very different geodesic distances. This has direct consequences:

  • OOD detection gains a geometric signal. Inputs that land in high-curvature regions are structurally anomalous, regardless of whether similar inputs appeared in red-teaming.
  • Calibration thresholds become local, not global. Flat regions warrant confidence. High-curvature regions warrant conservatism, and the geometry says which is which before seeing ground truth.
  • The training corpus leaves a geometric signature. A model trained on harmful content does not become malevolent. It becomes a system where harmful outputs are geometrically cheap, because the manifold is flat and well-sampled there. That is a structurally different and more concerning failure mode than explicit harmful intent.

Training considerations

The MetricNet adds ~700 parameters to a 354M model. The gradient signal reaching those parameters is inherently weak. Two measures address this:

1. Separate learning rate. MetricNet gets 10x the base LR (5e-4 vs 5e-5). Without this, G(x) may converge to identity not because the space is flat, but because the signal was too weak to learn structure.

2. Smoothness regularization. A penalty on the variation of G under small perturbations of the input coordinates:

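import torch

def metric_smoothness_loss(metric_net, coords, eps=0.01):
    # Penalize change in G under small random perturbations of the coordinates
    G = metric_net(coords)
    noise = torch.randn_like(coords) * eps
    # clamp assumes coordinates live in [0, 1]; gradient flows only through G
    G_perturbed = metric_net((coords + noise).clamp(0, 1))
    return (G - G_perturbed.detach()).pow(2).sum(dim=(-2, -1)).mean()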

Without this, G(x) can learn discontinuities that make the line integral numerically unstable and gradients noisy.
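In PyTorch, a separate learning rate like the one in point 1 is usually expressed with optimizer parameter groups. A minimal sketch, assuming AdamW (the post does not name the optimizer) and using small hypothetical modules in place of the real backbone and MetricNet:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the language model and the metric network
backbone = nn.Linear(16, 5)
metric_head = nn.Sequential(nn.Linear(5, 32), nn.Tanh(), nn.Linear(32, 15))

base_lr = 5e-5
optimizer = torch.optim.AdamW([
    {"params": backbone.parameters(), "lr": base_lr},
    {"params": metric_head.parameters(), "lr": 10 * base_lr},  # 10x for the metric
])

n_metric_params = sum(p.numel() for p in metric_head.parameters())
for group in optimizer.param_groups:
    print(group["lr"])
print(n_metric_params)   # 687 for the 5 -> 32 -> 15 head
```

Any scheduler attached to this optimizer scales each group's learning rate separately, so the 10x ratio is preserved.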


A note on quadrature stability

The implementation uses 5 Gauss-Legendre points by default, with pre-computed nodes and weights for efficiency. Tanh activation makes high-frequency variation unlikely, but you can verify convergence:

import numpy as np
import torch


def check_quadrature_convergence(metric_net, x1, x2,
                                 n_points_list=(5, 8, 16), device="cpu"):
    dx = x2 - x1
    for n in n_points_list:
        # Gauss-Legendre nodes/weights mapped from [-1, 1] to [0, 1]
        t_nodes, weights = np.polynomial.legendre.leggauss(n)
        t_nodes = (t_nodes + 1) / 2
        weights = weights / 2

        total = 0.0
        with torch.no_grad():  # required so .numpy() works on the network output
            for t, w in zip(t_nodes, weights):
                x_t = x1 + t * dx
                x_tensor = torch.tensor(x_t, dtype=torch.float32, device=device)
                G_np = metric_net(x_tensor.unsqueeze(0).unsqueeze(0))[0, 0].cpu().numpy()
                ds2 = dx @ G_np @ dx
                total += w * np.sqrt(max(ds2, 1e-12))

        print(f"  n={n:2d}: distance = {total:.6f}")

If distance does not stabilize between 5 and 16 points, the metric has high-frequency local variation. With Tanh, 5 points should be sufficient for most manifold geometries.
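How quickly the quadrature converges depends on how sharply G(x) changes, and Tanh units can still produce steep transitions if their weights grow large. This NumPy sketch uses a hypothetical conformal field G(x) = c(x)^2 I with c(x) = 1 + 0.5 tanh(k (sum(x) - 0.3)), whose straight-line length has a closed form, and compares 5-, 8- and 16-point Gauss-Legendre errors for a gentle (k = 2) and a steep (k = 50) transition.

```python
import numpy as np

def straight_line_length(p, q, conformal, n_quad):
    """Gauss-Legendre length of the segment p -> q under G(x) = conformal(x)**2 * I."""
    nodes, weights = np.polynomial.legendre.leggauss(n_quad)
    t_nodes, weights = (nodes + 1) / 2, weights / 2      # map [-1, 1] -> [0, 1]
    delta = q - p
    # sqrt(delta^T G delta) = conformal(x) * |delta| because conformal(x) > 0
    return sum(w * conformal(p + t * delta) * np.linalg.norm(delta)
               for t, w in zip(t_nodes, weights))

def exact_length(k, norm_delta):
    # Closed form: |delta| times the integral over [0, 1] of 1 + 0.5 tanh(k (t - 0.3))
    return norm_delta * (1 + 0.5 / k * (np.log(np.cosh(0.7 * k)) - np.log(np.cosh(0.3 * k))))

p, q = np.zeros(5), np.full(5, 0.2)    # sum(x) runs from 0 to 1 along the segment
errors = {}
for k in (2.0, 50.0):                  # gentle vs steep hypothetical transition
    def conformal(x, k=k):
        return 1 + 0.5 * np.tanh(k * (x.sum() - 0.3))
    exact = exact_length(k, np.linalg.norm(q - p))
    errors[k] = {n: abs(straight_line_length(p, q, conformal, n) - exact) for n in (5, 8, 16)}
    print(f"k={k:g}:", {n: f"{e:.1e}" for n, e in errors[k].items()})
```

With the gentle field, five points are already accurate; with the steep one, the five-point estimate is visibly off, which is the non-stabilizing pattern a convergence check is meant to catch.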


Current status and reproducibility

All three branches are live in the public repository. The baseline (branch main) is fully reproducible: training code, evaluation scripts, and the paper with full methodology are all public.

Results from the three-branch comparison will be published here and on ResearchGate when training is complete.

If you are working on calibration, OOD detection, or geometric approaches to uncertainty in language models, I am interested in talking. The repository is open and the methodology is fully documented.


Felipe Maya Muniz is the founder of AletheionAGI and an independent researcher developing ATIC, a geometric cognitive architecture for epistemic self-awareness in AI systems.
