DEV Community

Dany W

KNN-Based Risk Scoring: What Actually Works

KNN retrieval shows up in a lot of risk control pipelines. The general idea is simple: embed your content, find its nearest neighbors in a labeled seed pool, and aggregate the neighbor labels into a risk score. In practice, getting this to work well involves a surprising number of decisions that aren't obvious upfront.

This post covers the full diagnostic and tuning process — from auditing your seed pool to searching over scoring functions — based on running through 50+ scoring combinations on a real retrieval system.


The Setup

The retrieval system works like this:

  1. A query item (user, ad, content piece) is embedded into a vector.
  2. An ANN index returns the top-N nearest neighbors from a seed pool.
  3. Each neighbor carries a label (black = risky, white = benign) and a similarity score.
  4. These get aggregated into a single risk score, which is thresholded to produce a binary decision.

The interesting part is step 4. There are many reasonable ways to aggregate, and they behave very differently depending on the shape of your data.
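Here is a minimal, self-contained sketch of steps 1 to 4. It assumes the embeddings already exist and replaces the ANN index with an exact brute-force cosine search. For aggregation it uses a plain black fraction over neighbors above a similarity cutoff. `knn_risk_score`, `top_n`, `thr` and `decision_thr` are illustrative names and defaults, not the production system.

```python
import numpy as np

def knn_risk_score(query_vec, seed_vecs, seed_labels, top_n=50, thr=0.7, decision_thr=0.5):
    """Toy version of steps 1-4; exact search stands in for the ANN index."""
    # Step 1 (assumed done upstream): inputs are already embeddings.
    q = query_vec / np.linalg.norm(query_vec)
    seeds = seed_vecs / np.linalg.norm(seed_vecs, axis=1, keepdims=True)
    sims = seeds @ q                          # cosine similarity to every seed
    # Step 2: top-N nearest neighbors by similarity
    idx = np.argsort(-sims)[:top_n]
    # Step 3: each neighbor contributes (similarity, label); 1 = black, 0 = white
    scores, labels = sims[idx], seed_labels[idx]
    # Step 4: aggregate into one risk score, then threshold it
    keep = scores >= thr
    risk = float(labels[keep].mean()) if keep.any() else 0.0
    return risk, risk >= decision_thr

# Synthetic pool: black seeds around one direction, white seeds around another
rng = np.random.default_rng(0)
black = np.array([1.0, 0.0, 0.0]) + 0.05 * rng.standard_normal((20, 3))
white = np.array([0.0, 1.0, 0.0]) + 0.05 * rng.standard_normal((20, 3))
seed_vecs = np.vstack([black, white])
seed_labels = np.array([1] * 20 + [0] * 20)

risk, flagged = knn_risk_score(np.array([1.0, 0.1, 0.0]), seed_vecs, seed_labels, top_n=10)
print(risk, flagged)
```

The rest of this post is about what to put in place of that last aggregation line.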


Step 1: Audit Your Seed Pool First

Before touching the scoring function, check two things about your seed pool.

Class Balance

seed_df.groupBy("is_black").count().orderBy("is_black").show()

This matters more than most people expect. If your black-to-white ratio is above 10:1, almost every query will retrieve a neighborhood that is overwhelmingly black regardless of the query's true risk level. In that regime, ratio-based scoring functions become nearly useless — the signal is drowned out by the prior.

A rough rule of thumb: ratios above 10:1 mean you should fix the seed pool before tuning the scoring function. Ratios below 5:1 give the scoring function enough to work with.
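A little arithmetic shows why the prior dominates. Suppose a neighborhood were drawn at random from a pool with a black-to-white ratio of r:1. Its expected black fraction would then be r / (1 + r): about 0.83 at 5:1 and 0.91 at 10:1. The helper below computes that prior and applies the rule of thumb. The function names are hypothetical. The "borderline" label for ratios between 5:1 and 10:1 is an assumption, since the rule of thumb does not cover that range.

```python
def random_neighborhood_prior(black_count, white_count):
    """Expected black fraction of a neighborhood sampled at random from the seed pool."""
    ratio = black_count / white_count
    return ratio / (1 + ratio)

def seed_pool_verdict(black_count, white_count):
    """Apply the rule of thumb: above 10:1 fix the pool, below 5:1 fine to tune."""
    ratio = black_count / white_count
    if ratio > 10:
        return "fix seed pool first"
    if ratio < 5:
        return "ok to tune scoring function"
    return "borderline"  # assumption: 5:1 to 10:1 is not covered by the rule of thumb

# Counts as read off the groupBy("is_black").count() output (hypothetical numbers)
print(round(random_neighborhood_prior(120_000, 10_000), 3))  # 0.923
print(seed_pool_verdict(120_000, 10_000))                     # fix seed pool first
```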

Embedding Separability

Even with a balanced pool, the scoring function can only work if black and white seeds actually occupy different regions of the embedding space. Check this with a few cosine similarity statistics:

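import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# X: (n_seeds, dim) embedding matrix; y: 0/1 label array (1 = black)
rng = np.random.default_rng(0)

def sample_rows(M, n=500):
    # Sample up to n rows so the pairwise matrices stay small on large pools
    idx = rng.choice(len(M), size=min(n, len(M)), replace=False)
    return M[idx]

def mean_offdiag(S):
    # Mean pairwise similarity excluding the diagonal (self-similarity is always 1)
    n = S.shape[0]
    return (S.sum() - np.trace(S)) / (n * (n - 1))

black, white = X[y == 1], X[y == 0]
center_sim = cosine_similarity(black.mean(axis=0, keepdims=True),
                               white.mean(axis=0, keepdims=True))[0, 0]

black_sample, white_sample = sample_rows(black), sample_rows(white)
intra_black = mean_offdiag(cosine_similarity(black_sample))
intra_white = mean_offdiag(cosine_similarity(white_sample))
inter       = cosine_similarity(black_sample, white_sample).mean()

print(f"center similarity:  {center_sim:.4f}")
print(f"intra-black:        {intra_black:.4f}")
print(f"intra-white:        {intra_white:.4f}")
print(f"inter-class:        {inter:.4f}")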

The number to watch is whether inter < intra. If inter-class similarity is lower than intra-class, the two classes are separable in embedding space and a good scoring function can exploit that. If they are roughly equal, the embedding itself is the bottleneck — no scoring trick will fix it.
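To see what "separable" looks like in these numbers, the sketch below computes the same intra/inter statistics with plain NumPy on two synthetic pools. In one, the classes sit around different directions. In the other, both classes come from the same distribution. The data and the `cosine_stats` helper are illustrative only.

```python
import numpy as np

def cosine_stats(X, y):
    """Mean intra-class and inter-class cosine similarity, diagonal excluded."""
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)   # unit-normalize rows
    B, W = Xn[y == 1], Xn[y == 0]

    def mean_offdiag(A):
        S = A @ A.T
        n = len(A)
        return float((S.sum() - np.trace(S)) / (n * (n - 1)))

    return {
        "intra_black": mean_offdiag(B),
        "intra_white": mean_offdiag(W),
        "inter": float((B @ W.T).mean()),
    }

rng = np.random.default_rng(0)
y = np.array([1] * 200 + [0] * 200)
# Separable: black seeds cluster around one direction, white seeds around another
X_sep = np.vstack([rng.normal([3, 0, 0], 1, (200, 3)), rng.normal([0, 3, 0], 1, (200, 3))])
# Not separable: both classes drawn from the same distribution
X_mix = rng.normal([3, 3, 0], 1, (400, 3))

print("separable:    ", cosine_stats(X_sep, y))
print("not separable:", cosine_stats(X_mix, y))
```

In the separable pool, inter-class similarity sits far below both intra-class values. In the mixed pool, all three are nearly identical. That second pattern is the one that points to the embedding as the bottleneck.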


Step 2: Check Candidate Composition

Before evaluating any scoring function, look at what your retrieved neighborhoods actually look like.

pdf["black_count"] = pdf["candidate_labels"].apply(sum)
pdf["white_count"] = pdf["total_candidates"] - pdf["black_count"]
print(pdf[["total_candidates", "black_count", "white_count"]].describe())
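The summary table at the end of this post treats neighborhoods that are more than 95% black as a sign that seed pool imbalance is the bottleneck. A small helper can report how often that happens. It assumes `candidate_labels` holds one 0/1 list per query, as in the snippet above. The function name and output keys are hypothetical.

```python
def composition_summary(candidate_labels, black_share_cutoff=0.95):
    """Mean black share per neighborhood and the fraction of queries above the cutoff."""
    shares = [sum(labels) / len(labels) for labels in candidate_labels if len(labels) > 0]
    if not shares:
        return {"mean_black_share": 0.0, "frac_over_cutoff": 0.0}
    over = sum(1 for s in shares if s > black_share_cutoff)
    return {
        "mean_black_share": sum(shares) / len(shares),
        "frac_over_cutoff": over / len(shares),
    }

# Four hypothetical queries: all black, 19/20 black, half black, all white
example = [[1] * 20, [1] * 19 + [0], [1] * 10 + [0] * 10, [0] * 5]
print(composition_summary(example))
```

With pandas, the same call would be `composition_summary(pdf["candidate_labels"].tolist())`.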

Also compare the similarity score distributions for truly risky vs benign queries:

print("risky queries:",  pdf[pdf["gt_label"]==1]["score_max_black"].describe())
print("benign queries:", pdf[pdf["gt_label"]==0]["score_max_black"].describe())

If the mean difference between the two distributions is less than 0.05, the retrieval signal itself is weak and you are unlikely to get a useful scoring function out of it regardless of what aggregation you use. A mean difference above 0.10 is a reasonable starting point.
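The same check can be written as a function that returns the gap and how to read it. The inputs would be the two `score_max_black` series above (for example `pdf[pdf["gt_label"]==1]["score_max_black"]`). The name `retrieval_gap` is hypothetical, and calling the 0.05 to 0.10 range "marginal" is an assumption, since the rule of thumb leaves it open.

```python
from statistics import fmean

def retrieval_gap(risky_scores, benign_scores):
    """Mean score_max_black gap between risky and benign queries, with a rule-of-thumb reading."""
    gap = fmean(risky_scores) - fmean(benign_scores)
    if gap < 0.05:
        verdict = "weak retrieval signal"
    elif gap > 0.10:
        verdict = "reasonable starting point"
    else:
        verdict = "marginal"  # assumption: 0.05 to 0.10 is not covered by the rule of thumb
    return round(gap, 4), verdict

# Hypothetical score_max_black values for a handful of labeled queries
print(retrieval_gap([0.91, 0.88, 0.93], [0.80, 0.76, 0.79]))
```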


Step 3: Similarity Filtering

One of the clearest findings from running these experiments: filtering out low-similarity neighbors before scoring makes a large difference.

The candidates returned by ANN retrieval are not all equally informative. Neighbors with similarity below ~0.66 tend to be coincidental matches — they share some superficial feature but are not meaningfully related to the query. Including them adds noise without adding signal.

We tested four filter thresholds:

Filter threshold (thr)   Effect
0.0                      All neighbors included; high noise
0.3                      Minimal effect
0.5                      Slight improvement
0.7                      Largest consistent gain

Setting thr=0.7 was the most impactful single hyperparameter choice across all scoring functions tested, so it is worth applying as the default filter everywhere. Values of 0.75 or 0.80 can squeeze out a bit more precision at the cost of some recall and are worth trying if precision is the priority.
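A toy neighborhood shows how weak matches distort an unweighted score. The numbers are hypothetical, chosen only to illustrate the mechanism, and are not taken from the experiments. There are two close benign neighbors and three low-similarity black matches. The black fraction is 0.6 with no filter and drops to 0.0 once thr = 0.7 removes the weak matches.

```python
def black_fraction(scores, labels, thr=0.0):
    """Unweighted black fraction over neighbors with similarity >= thr."""
    kept = [l for s, l in zip(scores, labels) if s >= thr]
    return sum(kept) / len(kept) if kept else 0.0

# Hypothetical neighborhood: two close benign neighbors, three weak black matches
scores = [0.92, 0.88, 0.65, 0.60, 0.45]
labels = [0, 0, 1, 1, 1]

results = {thr: black_fraction(scores, labels, thr) for thr in (0.0, 0.3, 0.5, 0.7)}
for thr, frac in results.items():
    print(f"thr={thr}: black fraction {frac:.2f}")
```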


Step 4: Scoring Functions

With filtering in place, the choice of aggregation function becomes the main lever. Here is a tour of what was tested and what worked.

Majority Vote

def majority_vote(scores, labels, thr=0.7):
    valid = [l for s, l in zip(scores, labels) if s >= thr]
    return sum(valid) / len(valid) if valid else 0.0

Majority vote gives equal weight to all neighbors above the threshold. This is the natural baseline. Its weakness: a query that happens to land near a cluster of marginally similar black seeds will score high even if its closest neighbors are actually benign.
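A hypothetical neighborhood makes that weakness concrete. It has two near-duplicate benign neighbors and four black neighbors that only just clear the threshold. The numbers are made up for illustration. The function below is the same majority vote as above.

```python
def majority_vote(scores, labels, thr=0.7):
    valid = [l for s, l in zip(scores, labels) if s >= thr]
    return sum(valid) / len(valid) if valid else 0.0

# Hypothetical query: two near-duplicate benign neighbors, four marginal black ones
scores = [0.95, 0.93, 0.72, 0.72, 0.71, 0.71]
labels = [0, 0, 1, 1, 1, 1]
score = majority_vote(scores, labels)
print(score)  # 4 of the 6 kept neighbors are black
```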

Power Decay

def power_decay(scores, labels, p=8, thr=0.7):
    weights, black_w = 0.0, 0.0
    for s, l in zip(scores, labels):
        if s < thr:
            continue
        d = max(1 - s, 1e-6)
        w = 1.0 / (d ** p)
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0

Weight each neighbor by 1 / d^p, where the distance d = 1 - similarity. Higher p concentrates influence on the closest neighbors: at p=8, a neighbor at distance 0.1 outweighs one at distance 0.3 by a factor of (0.3/0.1)^8 = 6561. In practice this means only the two or three nearest neighbors actually matter.

p=8 with thr=0.7 was the best single configuration tested, outperforming the majority-vote baseline by 24% on F1 at recall ≥ 0.5.
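The sketch below checks the 6561 figure. It then shows how the normalized 1/d^p weights spread over a hypothetical neighborhood with two close benign neighbors and four marginal black ones. `power_weights` is an illustrative helper. The power-decay score is simply the share of total weight held by black neighbors.

```python
def power_weights(scores, p=8, thr=0.7):
    """Normalized 1/d^p weights (d = 1 - similarity); neighbors below thr get weight 0."""
    raw = [1.0 / max(1 - s, 1e-6) ** p if s >= thr else 0.0 for s in scores]
    total = sum(raw)
    return [w / total for w in raw] if total > 0 else raw

# Weight ratio between distance 0.1 and distance 0.3 at p=8
print(round((0.3 / 0.1) ** 8))  # 6561

scores = [0.95, 0.93, 0.72, 0.72, 0.71, 0.71]
labels = [0, 0, 1, 1, 1, 1]
shares = power_weights(scores)
print([round(w, 6) for w in shares])
# Power-decay score = share of total weight carried by black neighbors
black_share = sum(w for w, l in zip(shares, labels) if l == 1)
print(black_share)
```

The two nearest neighbors carry more than 99.99% of the weight, so the four marginal black matches barely move the score.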

Exponential and Gaussian Decay

import numpy as np

def exp_decay(scores, labels, sigma=0.1, thr=0.7):
    weights, black_w = 0.0, 0.0
    for s, l in zip(scores, labels):
        if s < thr:
            continue
        d = 1 - s
        w = np.exp(-d / sigma)
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0

def gaussian(scores, labels, sigma=0.1, thr=0.7):
    weights, black_w = 0.0, 0.0
    for s, l in zip(scores, labels):
        if s < thr:
            continue
        d = 1 - s
        w = np.exp(-(d ** 2) / (2 * sigma ** 2))
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0

Both behave similarly to power decay at small sigma values. Gaussian decay with sigma=0.1 is roughly equivalent to power_decay_p8. They did not outperform power decay in these experiments but are worth including in a sweep since they have different sensitivity profiles.
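One way to see the different sensitivity profiles is to compare how much more a neighbor at distance 0.1 counts than one at distance 0.3 (the edge of the thr=0.7 filter). The sketch does this for each kernel, using the default p=8 and sigma=0.1. `kernel_weight` is a hypothetical helper that restates the three weight formulas above.

```python
import math

def kernel_weight(d, kind, p=8, sigma=0.1):
    """Unnormalized weight at distance d (= 1 - similarity) for each decay kernel."""
    if kind == "power":
        return 1.0 / max(d, 1e-6) ** p
    if kind == "exp":
        return math.exp(-d / sigma)
    if kind == "gaussian":
        return math.exp(-(d ** 2) / (2 * sigma ** 2))
    raise ValueError(f"unknown kernel: {kind}")

# How much more a neighbor at d=0.1 counts than one at d=0.3
ratios = {k: kernel_weight(0.1, k) / kernel_weight(0.3, k) for k in ("power", "exp", "gaussian")}
for kind, r in ratios.items():
    print(f"{kind:8s} {r:10.1f}")
```

Over this range, power decay is by far the steepest (6561x, versus e^4 ≈ 54.6 for the Gaussian and e^2 ≈ 7.4 for the exponential). It is also the only kernel whose weight grows without bound as d approaches 0; here only the 1e-6 clamp caps it. The exponential and Gaussian weights never exceed 1. The Gaussian is flat near d = 0, so it barely distinguishes among very close neighbors.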

Top-K Ratio

def topk_ratio(scores, labels, k=30, thr=0.7):
    # Sort on similarity only, so tied similarities are not ordered by label
    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)[:k]
    valid = [l for s, l in pairs if s >= thr]
    return sum(valid) / len(valid) if valid else 0.0

Take only the k closest neighbors and compute the black fraction. This is a simpler version of the distance-weighting idea — instead of downweighting distant neighbors, just ignore them entirely. k=10 and k=30 both performed well. Worth comparing directly against power decay since it is easier to reason about.
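A small hypothetical neighborhood shows the hard cutoff at work. It has two near-duplicate benign neighbors followed by four marginal black ones, so the score depends entirely on whether k reaches past the benign pair. This version sorts on similarity only (the `key` argument), so tied similarities are not broken by label. The neighborhood is invented for illustration.

```python
def topk_ratio(scores, labels, k=30, thr=0.7):
    # Keep the k most similar neighbors, then the black fraction among those above thr
    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)[:k]
    valid = [l for s, l in pairs if s >= thr]
    return sum(valid) / len(valid) if valid else 0.0

scores = [0.95, 0.93, 0.72, 0.72, 0.71, 0.71]
labels = [0, 0, 1, 1, 1, 1]
results = {k: topk_ratio(scores, labels, k=k) for k in (2, 4, 6)}
for k, r in results.items():
    print(f"k={k}: {r:.3f}")
```

On a neighborhood this small the score jumps with k. On real data, with many neighbors above the threshold, k=10 and k=30 can behave much more alike.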

Top-K + Power Decay

def topk_power_decay(scores, labels, k=30, p=8, thr=0.7):
    # Sort on similarity only, so tied similarities are not ordered by label
    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)[:k]
    weights, black_w = 0.0, 0.0
    for s, l in pairs:
        if s < thr:
            continue
        d = max(1 - s, 1e-6)
        w = 1.0 / (d ** p)
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0

Combines both ideas: hard cutoff at k neighbors, then distance weighting within that set. Slightly smoother than pure top-k ratio, slightly less sensitive to distant noise than full power decay. A reasonable default if you are starting from scratch.
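For scoring many queries at once, the same computation can be vectorized with NumPy. The sketch assumes candidates have been packed into fixed-width `(n_queries, n_candidates)` arrays. Missing slots are padded with similarity -1 so the thr filter drops them. `topk_power_decay_batch` is a hypothetical name, not part of any library.

```python
import numpy as np

def topk_power_decay_batch(S, L, k=30, p=8, thr=0.7):
    """Top-k + power decay for a batch of queries.

    S: (n_queries, n_candidates) similarities, padded with -1 for missing slots.
    L: same-shape 0/1 labels (1 = black).
    """
    order = np.argsort(-S, axis=1)[:, :k]          # k most similar candidates per row
    s = np.take_along_axis(S, order, axis=1)
    lab = np.take_along_axis(L, order, axis=1)
    d = np.maximum(1.0 - s, 1e-6)                  # distance = 1 - similarity
    w = np.where(s >= thr, 1.0 / d ** p, 0.0)      # zero weight below thr
    total = w.sum(axis=1)
    black = (w * lab).sum(axis=1)
    # Rows with no neighbor above thr score 0.0, matching the scalar version
    return np.divide(black, total, out=np.zeros_like(total), where=total > 0)

S = np.array([[0.95, 0.93, 0.72, 0.72, 0.71, 0.71],
              [0.90, 0.85, 0.60, -1.0, -1.0, -1.0],
              [0.65, 0.60, -1.0, -1.0, -1.0, -1.0]])
L = np.array([[0, 0, 1, 1, 1, 1],
              [1, 1, 0, 0, 0, 0],
              [1, 1, 0, 0, 0, 0]])
out = topk_power_decay_batch(S, L, k=4)
print(out)
```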

Rank-Based Scoring

def rank_based(scores, labels, thr=0.7):
    # Sort on similarity only, so tied similarities are not ordered by label
    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    weights, black_w = 0.0, 0.0
    for rank, (s, l) in enumerate(pairs, start=1):
        if s < thr:
            break
        w = 1.0 / rank
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0

Uses rank rather than distance as the weighting signal. This is more robust to embedding normalization issues — if your similarity scores are not well-calibrated, the relative ordering is more reliable than the absolute values. Worth testing when distance-based methods underperform.
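The sketch below illustrates that claim under one specific assumption: an order-preserving rescale of similarities that keeps every neighbor above thr. Here the scores are compressed toward 0.7, as a differently normalized embedding might do. The rank-based score is unchanged, while the power-decay score shifts. Both functions restate the ones above, with similarity-only sorting. The data are hypothetical.

```python
def rank_based(scores, labels, thr=0.7):
    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    weights = black_w = 0.0
    for rank, (s, l) in enumerate(pairs, start=1):
        if s < thr:
            break
        weights += 1.0 / rank          # weight depends only on position, not on s
        black_w += l / rank
    return black_w / weights if weights > 0 else 0.0

def power_decay(scores, labels, p=8, thr=0.7):
    pairs = [(max(1 - s, 1e-6), l) for s, l in zip(scores, labels) if s >= thr]
    weights = sum(1.0 / d ** p for d, _ in pairs)
    black_w = sum(l / d ** p for d, l in pairs)
    return black_w / weights if weights > 0 else 0.0

scores = [0.80, 0.78, 0.76, 0.74]
labels = [1, 0, 0, 0]
# Same ordering, compressed toward 0.7; every neighbor still clears thr
squashed = [0.7 + 0.5 * (s - 0.7) for s in scores]

print("rank:  ", rank_based(scores, labels), rank_based(squashed, labels))
print("power: ", power_decay(scores, labels), power_decay(squashed, labels))
```

If the rescale pushed some neighbors below thr, the rank-based score would change too, because the similarity filter still acts on absolute values.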


Step 5: Evaluation

Evaluate each combination on a held-out set with ground truth labels. AUC captures overall ranking quality; F1 under a recall constraint captures real-world operating points.

from sklearn.metrics import roc_auc_score, precision_recall_curve
import numpy as np

def evaluate(scores, labels, recall_floor=0.5):
    auc = roc_auc_score(labels, scores)
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    f1 = 2 * precision[:-1] * recall[:-1] / (precision[:-1] + recall[:-1] + 1e-9)
    mask = recall[:-1] >= recall_floor
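    # Best F1 among operating points that meet the recall floor (fall back to best overall)
    best = int(np.argmax(np.where(mask, f1, -1.0))) if mask.any() else int(np.argmax(f1))
    return {
        "auc":       round(float(auc), 4),
        "threshold": round(float(thresholds[best]), 4),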
        "precision": round(float(precision[best]), 4),
        "recall":    round(float(recall[best]), 4),
        "f1":        round(float(f1[best]), 4),
    }

Run this for each (scoring function, thr, p/sigma/k) combination and sort by AUC. In practice you will find that the best configuration is consistent across evaluation dates — if it is not, the variation is usually driven by seed pool drift rather than scoring function instability.
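A small sweep harness along those lines is sketched below. To keep it self-contained, it computes AUC with the Mann-Whitney formulation instead of importing scikit-learn. For binary labels this equals `roc_auc_score`, with ties counted as one half. It also uses compact restatements of two scoring functions from above. The four neighborhoods are synthetic, so the ranking they produce says nothing about real data. `sweep` and its argument layout are hypothetical.

```python
import itertools
import numpy as np

def auc(y_true, y_score):
    """ROC AUC as the Mann-Whitney statistic: P(pos score > neg score), ties count 1/2."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    diff = y_score[y_true == 1][:, None] - y_score[y_true == 0][None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())

def majority_vote(scores, labels, thr=0.7):
    kept = [l for s, l in zip(scores, labels) if s >= thr]
    return sum(kept) / len(kept) if kept else 0.0

def power_decay(scores, labels, p=8, thr=0.7):
    pairs = [(1.0 / max(1 - s, 1e-6) ** p, l) for s, l in zip(scores, labels) if s >= thr]
    total = sum(w for w, _ in pairs)
    return sum(w * l for w, l in pairs) / total if total > 0 else 0.0

def sweep(neighborhoods, gt, scorers, grids):
    """Evaluate every (scorer, hyperparameter) combination and rank by AUC.

    neighborhoods: list of (scores, labels) per query; gt: 0/1 ground truth per query.
    scorers: name -> fn(scores, labels, **params); grids: name -> {param: [values]}.
    """
    results = []
    for name, fn in scorers.items():
        grid = grids[name]
        for values in itertools.product(*grid.values()):
            params = dict(zip(grid.keys(), values))
            preds = [fn(s, l, **params) for s, l in neighborhoods]
            results.append({"scorer": name, **params, "auc": round(auc(gt, preds), 4)})
    return sorted(results, key=lambda r: r["auc"], reverse=True)

# Four synthetic queries: (neighbor similarities, neighbor labels)
neighborhoods = [
    ([0.95, 0.93, 0.72, 0.71], [0, 0, 1, 1]),
    ([0.96, 0.90, 0.75, 0.72], [1, 1, 0, 0]),
    ([0.94, 0.80, 0.74, 0.73], [0, 1, 1, 1]),
    ([0.97, 0.78, 0.76, 0.60], [1, 0, 0, 0]),
]
gt = [0, 1, 0, 1]
scorers = {"majority_vote": majority_vote, "power_decay": power_decay}
grids = {"majority_vote": {"thr": [0.0, 0.7]},
         "power_decay": {"thr": [0.0, 0.7], "p": [2, 8]}}

ranked = sweep(neighborhoods, gt, scorers, grids)
for row in ranked[:3]:
    print(row)
```

To rank by the operating-point metric instead, swap `auc` for the F1-at-recall-floor value returned by `evaluate`.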


Step 6: Optional Calibration

Power decay scores are not well-calibrated probabilities — they tend to cluster near the extremes. If you need consistent thresholds across different retrieval keys or want to compare scores across systems, isotonic regression calibration is a low-effort improvement.

from sklearn.isotonic import IsotonicRegression

ir = IsotonicRegression(out_of_bounds="clip")
ir.fit(raw_scores_train, gt_labels_train)
calibrated = ir.predict(raw_scores_test)

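Do not expect an AUC gain from this. Isotonic regression is a monotone (non-decreasing) transform, so it leaves the ranking essentially intact; AUC can shift slightly only where it merges nearby raw scores into ties. What it buys is threshold selection that is more intuitive and more stable.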
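Before fitting the calibrator, a binned reliability table is a quick way to see how far raw scores are from probabilities. It compares the mean raw score with the observed black rate in each score bin. The sketch assumes scores lie in [0, 1], which holds for all the weighted-average scorers above. `reliability_table` and the synthetic scores are illustrative.

```python
import numpy as np

def reliability_table(scores, labels, n_bins=10):
    """(bin_low, bin_high, count, mean_score, observed_black_rate) for each non-empty bin."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=float)
    # Assumes scores lie in [0, 1]; a score of exactly 1.0 goes into the top bin
    bins = np.minimum((scores * n_bins).astype(int), n_bins - 1)
    rows = []
    for b in range(n_bins):
        m = bins == b
        if m.any():
            rows.append((b / n_bins, (b + 1) / n_bins, int(m.sum()),
                         round(float(scores[m].mean()), 3),
                         round(float(labels[m].mean()), 3)))
    return rows

# Synthetic raw scores piled up at the extremes
scores = [0.001] * 8 + [0.999] * 12
labels = [1, 1] + [0] * 6 + [1] * 8 + [0] * 4
for row in reliability_table(scores, labels):
    print(row)
```

Here scores near 0 and 1 correspond to observed rates of 0.25 and about 0.67. That is the kind of gap isotonic calibration is meant to close.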


Summary

What to check                          Signal / value     Action
Black/white seed ratio                 > 10:1             Fix seed pool before tuning
Embedding inter vs intra similarity    inter ≈ intra      Retrain embedding
Candidate neighborhood composition     > 95% black        Seed pool imbalance is the bottleneck
Similarity score distribution gap      < 0.05             Retrieval signal too weak
Best filter threshold                  0.7                Apply as default; try 0.75/0.80 for precision
Best scoring function                  power_decay p=8    Try topk+power_decay as alternative

The pattern that comes up repeatedly: scoring function tuning has diminishing returns if the seed pool is imbalanced or the embedding is not separable. Those two things should be verified first. Given a healthy seed pool and separable embeddings, the jump from majority vote to power_decay_p8_thr0.7 is meaningful — roughly 20-25% F1 improvement in the experiments described here — and the combination search is worth running.
