KNN retrieval shows up in a lot of risk control pipelines. The general idea is simple: embed your content, find its nearest neighbors in a labeled seed pool, and aggregate the neighbor labels into a risk score. In practice, getting this to work well involves a surprising number of decisions that aren't obvious upfront.
This post covers the full diagnostic and tuning process — from auditing your seed pool to searching over scoring functions — based on running through 50+ scoring combinations on a real retrieval system.
The Setup
The retrieval system works like this:
- A query item (user, ad, content piece) is embedded into a vector.
- An ANN index returns the top-N nearest neighbors from a seed pool.
- Each neighbor carries a label (black = risky, white = benign) and a similarity score.
- These get aggregated into a single risk score, which is thresholded to produce a binary decision.
The interesting part is the aggregation step. There are many reasonable ways to aggregate, and they behave very differently depending on the shape of your data.
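The four steps above can be sketched end to end in a few lines. This is a minimal, hypothetical sketch, not the system described here. A brute-force NumPy cosine search stands in for the ANN index. The helper names `retrieve`, `black_fraction` and `decide` are illustrative. The aggregation is the plain black fraction above a similarity threshold, and the 0.5 decision cutoff is an assumed value.

```python
import numpy as np

def retrieve(query, seeds, seed_labels, top_n=50):
    """Return (similarities, labels) of the top_n seeds most cosine-similar to query.

    Brute-force stand-in for an ANN index; only practical for small seed pools.
    """
    q = query / np.linalg.norm(query)
    s = seeds / np.linalg.norm(seeds, axis=1, keepdims=True)
    sims = s @ q                          # cosine similarity to every seed
    top = np.argsort(-sims)[:top_n]       # indices of the top_n, most similar first
    return sims[top], seed_labels[top]

def black_fraction(sims, labels, thr=0.7):
    """Simplest aggregation: share of black (1) neighbors at or above thr."""
    keep = sims >= thr
    return float(labels[keep].mean()) if keep.any() else 0.0

def decide(query, seeds, seed_labels, cutoff=0.5, top_n=50):
    """Retrieve -> aggregate -> threshold for an already-embedded query (cutoff is assumed)."""
    sims, labels = retrieve(query, seeds, seed_labels, top_n)
    risk = black_fraction(sims, labels)
    return risk, risk >= cutoff

# Toy usage with random embeddings (hypothetical data)
rng = np.random.default_rng(0)
seeds = rng.normal(size=(1000, 16))
seed_labels = rng.integers(0, 2, size=1000)        # 1 = black, 0 = white
query = seeds[0] + 0.1 * rng.normal(size=16)       # a query close to one known seed
risk, is_risky = decide(query, seeds, seed_labels)
```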
Step 1: Audit Your Seed Pool First
Before touching the scoring function, check two things about your seed pool.
Class Balance
```python
seed_df.groupBy("is_black").count().orderBy("is_black").show()
```
This matters more than most people expect. If your black-to-white ratio is above 10:1, almost every query will retrieve a neighborhood that is overwhelmingly black regardless of the query's true risk level. In that regime, ratio-based scoring functions become nearly useless — the signal is drowned out by the prior.
A rough rule of thumb: ratios above 10:1 mean you should fix the seed pool before tuning the scoring function. Ratios below 5:1 give the scoring function enough to work with.
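The ratio check and its consequence can be scripted together. This is a minimal sketch. It assumes the `is_black` column has been pulled to the driver as 0/1 values, for example with `seed_df.select("is_black").toPandas()["is_black"]`. `seed_balance` is a hypothetical helper. The rule of thumb does not cover ratios between 5:1 and 10:1, so the "borderline" label for that band is a placeholder.

```python
import numpy as np

def seed_balance(is_black):
    """Black-to-white ratio, the implied black prior, and a rule-of-thumb verdict."""
    is_black = np.asarray(is_black)
    n_black = int(is_black.sum())
    n_white = len(is_black) - n_black
    ratio = n_black / n_white if n_white else float("inf")
    # Share of black neighbors expected if retrieval were unrelated to the labels
    prior_black = n_black / len(is_black)
    if ratio > 10:
        verdict = "fix seed pool first"
    elif ratio < 5:
        verdict = "ok to tune scoring"
    else:
        verdict = "borderline"   # 5:1 to 10:1 is not covered by the rule of thumb
    return ratio, prior_black, verdict
```

At exactly 10:1 the implied prior is 10/11 ≈ 0.91. Suppose retrieval ignored the query entirely. Even then, the average neighborhood would be about 91% black. This is the sense in which the prior drowns out the signal.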
Embedding Separability
Even with a balanced pool, the scoring function can only work if black and white seeds actually occupy different regions of the embedding space. Check this with a few cosine similarity statistics:
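```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# X: seed embeddings as a NumPy array; y: 0/1 labels (1 = black) as a NumPy array
rng = np.random.default_rng(0)

def sample_rows(M, n=500):
    # Sample to avoid memory issues on large pools (pairwise matrices are n x n)
    return M[rng.choice(len(M), size=min(n, len(M)), replace=False)]

def mean_offdiag(S):
    # Mean pairwise similarity, excluding each seed's similarity with itself
    n = S.shape[0]
    return (S.sum() - np.trace(S)) / (n * (n - 1))

black_center = X[y == 1].mean(axis=0)
white_center = X[y == 0].mean(axis=0)
center_sim = cosine_similarity([black_center], [white_center])[0][0]

black_sample = sample_rows(X[y == 1])
white_sample = sample_rows(X[y == 0])   # sample white seeds too, not just black
intra_black = mean_offdiag(cosine_similarity(black_sample))
intra_white = mean_offdiag(cosine_similarity(white_sample))
inter = cosine_similarity(black_sample, white_sample).mean()
print(f"center similarity: {center_sim:.4f}")
print(f"intra-black: {intra_black:.4f}")
print(f"intra-white: {intra_white:.4f}")
print(f"inter-class: {inter:.4f}")
```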
The number to watch is whether inter < intra. If inter-class similarity is lower than intra-class, the two classes are separable in embedding space and a good scoring function can exploit that. If they are roughly equal, the embedding itself is the bottleneck — no scoring trick will fix it.
Step 2: Check Candidate Composition
Before evaluating any scoring function, look at what your retrieved neighborhoods actually look like.
```python
pdf["black_count"] = pdf["candidate_labels"].apply(sum)
pdf["white_count"] = pdf["total_candidates"] - pdf["black_count"]
print(pdf[["total_candidates", "black_count", "white_count"]].describe())
```
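The summary table below flags neighborhoods that are more than 95% black as a sign that seed imbalance is the bottleneck. Here is a small sketch of that check. It assumes `candidate_labels` holds a list of 0/1 labels per query and `total_candidates` holds its length, as in the snippet above. `black_share_report` is a hypothetical helper. Queries with no candidates are counted separately rather than treated as 0% black.

```python
import pandas as pd

def black_share_report(pdf, cutoff=0.95):
    """Per-query share of black candidates, summarized across queries."""
    black = pdf["candidate_labels"].apply(sum)
    nonempty = pdf["total_candidates"] > 0
    share = black[nonempty] / pdf.loc[nonempty, "total_candidates"]
    return {
        "n_empty": int((~nonempty).sum()),
        "median_black_share": float(share.median()),
        "frac_over_cutoff": float((share > cutoff).mean()),   # share of queries above cutoff
    }

# Toy usage (hypothetical data)
toy = pd.DataFrame({"candidate_labels": [[1, 1, 1, 1], [1, 0, 0, 0], [1, 1], []],
                    "total_candidates": [4, 4, 2, 0]})
print(black_share_report(toy))
```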
Also compare the similarity score distributions for truly risky vs benign queries:
```python
print("risky queries:", pdf[pdf["gt_label"]==1]["score_max_black"].describe())
print("benign queries:", pdf[pdf["gt_label"]==0]["score_max_black"].describe())
```
If the mean difference between the two distributions is less than 0.05, the retrieval signal itself is weak and you are unlikely to get a useful scoring function out of it regardless of what aggregation you use. A mean difference above 0.10 is a reasonable starting point.
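The mean-gap rule of thumb can be turned into a quick check. This sketch assumes the `gt_label` and `score_max_black` columns used above. `retrieval_gap` is a hypothetical helper. The rule does not cover gaps between 0.05 and 0.10, so the "marginal" verdict for that band is a placeholder.

```python
import pandas as pd

def retrieval_gap(pdf, col="score_max_black", weak=0.05, good=0.10):
    """Mean gap in `col` between truly risky and benign queries, plus a verdict."""
    risky = pdf.loc[pdf["gt_label"] == 1, col].mean()
    benign = pdf.loc[pdf["gt_label"] == 0, col].mean()
    gap = float(risky - benign)
    if gap < weak:
        verdict = "weak"          # retrieval signal too weak for any aggregation
    elif gap > good:
        verdict = "reasonable"
    else:
        verdict = "marginal"      # band not covered by the rule of thumb
    return gap, verdict

# Toy usage (hypothetical data)
toy = pd.DataFrame({"gt_label": [1, 1, 0, 0],
                    "score_max_black": [0.90, 0.80, 0.70, 0.60]})
print(retrieval_gap(toy))   # gap ≈ 0.2, verdict "reasonable"
```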
Step 3: Similarity Filtering
One of the clearest findings from running these experiments: filtering out low-similarity neighbors before scoring makes a large difference.
The candidates returned by ANN retrieval are not all equally informative. Neighbors with similarity below ~0.66 tend to be coincidental matches — they share some superficial feature but are not meaningfully related to the query. Including them adds noise without adding signal.
We tested four filter thresholds:
| Filter threshold (thr) | Effect |
|---|---|
| 0.0 | All neighbors included; high noise |
| 0.3 | Minimal effect |
| 0.5 | Slight improvement |
| 0.7 | Largest consistent gain |
The similarity filter was the most impactful single hyperparameter across all scoring functions tested, and thr=0.7 gave the largest consistent gain. It is worth applying as a default baseline. Values of 0.75 or 0.80 can squeeze out a bit more precision at the cost of some recall, and are worth trying if precision is the priority.
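One mechanism behind the recall cost of a higher threshold is worth checking directly. Every scoring function in this post returns 0.0 when no neighbor clears `thr`. A query whose whole neighborhood falls below the threshold can therefore never be flagged. The sketch below reports, for each threshold, how many neighbors survive on average and what share of queries are left with none. It assumes similarities are available as a NumPy array of shape (n_queries, top_N). `coverage_by_threshold` and the toy data are hypothetical.

```python
import numpy as np

def coverage_by_threshold(sims, thresholds=(0.0, 0.3, 0.5, 0.7, 0.75, 0.8)):
    """For each thr: mean neighbors kept per query, and share of queries with none kept."""
    sims = np.asarray(sims, dtype=float)
    out = {}
    for thr in thresholds:
        kept = (sims >= thr).sum(axis=1)   # neighbors surviving the filter, per query
        out[thr] = (float(kept.mean()), float((kept == 0).mean()))
    return out

# Toy usage: random similarities for 1,000 queries x 50 neighbors (hypothetical)
toy_sims = np.random.default_rng(0).uniform(0.3, 1.0, size=(1000, 50))
for thr, (mean_kept, frac_empty) in coverage_by_threshold(toy_sims).items():
    print(f"thr={thr:.2f}  kept/query={mean_kept:.1f}  queries with none={frac_empty:.1%}")
```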
Step 4: Scoring Functions
With filtering in place, the choice of aggregation function becomes the main lever. Here is a tour of what was tested and what worked.
Majority Vote
```python
def majority_vote(scores, labels, thr=0.7):
    # Black fraction among neighbors at or above the similarity threshold
    valid = [l for s, l in zip(scores, labels) if s >= thr]
    return sum(valid) / len(valid) if valid else 0.0
```
Majority vote gives equal weight to every neighbor above the threshold. This is the natural baseline. Its weakness: a query that happens to land near a cluster of marginally similar black seeds will score high even if its closest neighbors are actually benign.
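A hypothetical neighborhood makes the weakness concrete. The two closest seeds are white, and four marginal black seeds sit just above the threshold. The similarity values are made up for illustration.

```python
def majority_vote(scores, labels, thr=0.7):
    valid = [l for s, l in zip(scores, labels) if s >= thr]
    return sum(valid) / len(valid) if valid else 0.0

# Hypothetical neighborhood (1 = black, 0 = white)
scores = [0.95, 0.94, 0.72, 0.71, 0.71, 0.70]
labels = [0, 0, 1, 1, 1, 1]
print(round(majority_vote(scores, labels), 3))  # 0.667: the marginal black seeds win the vote
```

Majority vote flags this query at 0.667. Yet its two nearest neighbors, both at similarity 0.94 or above, are benign.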
Power Decay
```python
def power_decay(scores, labels, p=8, thr=0.7):
    weights, black_w = 0.0, 0.0
    for s, l in zip(scores, labels):
        if s < thr:
            continue  # drop low-similarity neighbors
        d = max(1 - s, 1e-6)  # distance = 1 - similarity, clipped to avoid division by zero
        w = 1.0 / (d ** p)
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0
```
Weight each neighbor by 1 / distance^p, where distance = 1 - similarity. Higher p concentrates influence on the closest neighbors. At p=8, a neighbor at distance 0.1 outweighs one at distance 0.3 by a factor of (0.3/0.1)^8 = 6561. In practice this means only the two or three nearest neighbors actually matter.
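The same arithmetic can be applied to a whole neighborhood to see how concentrated the weight becomes. The sketch below uses a hypothetical neighborhood: two white seeds at similarity 0.95 and 0.94, and four black seeds between 0.70 and 0.72. It normalizes the p=8 weights, and `power_weights` is an illustrative helper. The power-decay score is the black share of total weight. The last printed line is therefore also what `power_decay` would return for this neighborhood with thr=0.7.

```python
import numpy as np

# Hypothetical neighborhood: two close white seeds (0), four marginal black seeds (1)
scores = np.array([0.95, 0.94, 0.72, 0.71, 0.71, 0.70])
labels = np.array([0, 0, 1, 1, 1, 1])

def power_weights(scores, p=8):
    """Each neighbor's share of the total 1 / d**p weight, with d = 1 - similarity."""
    d = np.maximum(1 - scores, 1e-6)
    w = 1.0 / d ** p
    return w / w.sum()

share = power_weights(scores)
print(np.round(share, 6))
print(f"black share of weight: {share[labels == 1].sum():.1e}")   # ~2.6e-06
```

Majority vote would score the same neighborhood at 4/6 ≈ 0.67. Power decay instead is driven almost entirely by the two white seeds.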
p=8 with thr=0.7 was the best single configuration tested, outperforming the baseline by 24% on F1 at recall ≥ 0.5.
Exponential and Gaussian Decay
```python
import numpy as np

def exp_decay(scores, labels, sigma=0.1, thr=0.7):
    weights, black_w = 0.0, 0.0
    for s, l in zip(scores, labels):
        if s < thr:
            continue
        d = 1 - s
        w = np.exp(-d / sigma)  # exponential kernel on distance
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0

def gaussian(scores, labels, sigma=0.1, thr=0.7):
    weights, black_w = 0.0, 0.0
    for s, l in zip(scores, labels):
        if s < thr:
            continue
        d = 1 - s
        w = np.exp(-(d ** 2) / (2 * sigma ** 2))  # Gaussian kernel on distance
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0
```
Both behave similarly to power decay at small sigma values. Gaussian decay with sigma=0.1 is roughly equivalent to power_decay_p8. They did not outperform power decay in these experiments but are worth including in a sweep since they have different sensitivity profiles.
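To see the different sensitivity profiles, it helps to compare the three kernels at the same distances. Here distance = 1 - similarity, and thr=0.7 caps it at 0.3. This is a minimal sketch with p=8 and sigma=0.1, the values discussed above. `kernel_weights` is a hypothetical helper. Each printed row gives weights relative to a neighbor at distance 0.1.

```python
import numpy as np

def kernel_weights(d, p=8, sigma=0.1):
    # Clip as power_decay does; the clip is irrelevant for the other kernels at these distances
    d = np.maximum(np.asarray(d, dtype=float), 1e-6)
    return {
        "power": 1.0 / d ** p,
        "exp": np.exp(-d / sigma),
        "gaussian": np.exp(-(d ** 2) / (2 * sigma ** 2)),
    }

d = np.array([0.05, 0.1, 0.2, 0.3])
for name, w in kernel_weights(d).items():
    rel = w / w[1]    # weight relative to a neighbor at distance 0.1
    print(f"{name:>8}: " + "  ".join(f"{r:.3g}" for r in rel))
```

Going from distance 0.1 to 0.3, the weight drops by a factor of 6561 for power decay. It drops by e^2 ≈ 7.4 for exponential decay and by e^4 ≈ 55 for the Gaussian. At these settings, the three kernels spread weight quite differently across the filtered neighborhood, even when their end-to-end metrics are close.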
Top-K Ratio
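```python
def topk_ratio(scores, labels, k=30, thr=0.7):
    # Sort on similarity only; sorting (score, label) tuples would break ties in favor of black
    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)[:k]
    valid = [l for s, l in pairs if s >= thr]
    return sum(valid) / len(valid) if valid else 0.0
```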
Take only the k closest neighbors and compute the black fraction. This is a simpler version of the distance-weighting idea — instead of downweighting distant neighbors, just ignore them entirely. k=10 and k=30 both performed well. Worth comparing directly against power decay since it is easier to reason about.
Top-K + Power Decay
```python
def topk_power_decay(scores, labels, k=30, p=8, thr=0.7):
    # Keep the k most similar neighbors (sorted on similarity only), then distance-weight them
    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)[:k]
    weights, black_w = 0.0, 0.0
    for s, l in pairs:
        if s < thr:
            continue
        d = max(1 - s, 1e-6)
        w = 1.0 / (d ** p)
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0
```
Combines both ideas: hard cutoff at k neighbors, then distance weighting within that set. Slightly smoother than pure top-k ratio, slightly less sensitive to distant noise than full power decay. A reasonable default if you are starting from scratch.
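The loop-based functions above score one query at a time. That gets slow when evaluating 50+ combinations over a large held-out set. Below is a vectorized NumPy sketch of `topk_power_decay` over a batch. `topk_power_decay_batch` is a hypothetical name. It assumes every query has the same number of retrieved neighbors; shorter rows can be padded with a similarity below `thr`, such as -1.0.

```python
import numpy as np

def topk_power_decay_batch(sims, labels, k=30, p=8, thr=0.7):
    """Vectorized topk_power_decay.

    sims, labels: arrays of shape (n_queries, n_neighbors); rows need not be sorted.
    Returns shape (n_queries,); rows with no neighbor >= thr score 0.0.
    """
    sims = np.asarray(sims, dtype=float)
    labels = np.asarray(labels, dtype=float)
    order = np.argsort(-sims, axis=1, kind="stable")[:, :k]   # top-k per row, by similarity
    s = np.take_along_axis(sims, order, axis=1)
    l = np.take_along_axis(labels, order, axis=1)
    d = np.maximum(1.0 - s, 1e-6)
    w = np.where(s >= thr, d ** -p, 0.0)                       # zero weight below thr
    total = w.sum(axis=1)
    black = (w * l).sum(axis=1)
    return np.divide(black, total, out=np.zeros_like(total), where=total > 0)

# Toy usage (hypothetical data): three queries, two neighbors each
print(topk_power_decay_batch([[0.95, 0.5], [0.6, 0.5], [0.8, 0.8]],
                             [[1, 0], [1, 1], [1, 0]]))
```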
Rank-Based Scoring
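```python
def rank_based(scores, labels, thr=0.7):
    # Sort on similarity only so tied scores do not favor black labels
    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    weights, black_w = 0.0, 0.0
    for rank, (s, l) in enumerate(pairs, start=1):
        if s < thr:
            break  # sorted descending, so every later neighbor is also below thr
        w = 1.0 / rank
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0
```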
Uses rank rather than distance as the weighting signal. This is more robust to embedding normalization issues — if your similarity scores are not well-calibrated, the relative ordering is more reliable than the absolute values. Worth testing when distance-based methods underperform.
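The robustness claim can be checked on a toy case. Take two neighborhoods with the same ordering and labels but different similarity scales. This sketch copies `rank_based` (sorting on similarity alone) and a compact `power_decay`. The neighborhoods are hypothetical.

```python
def rank_based(scores, labels, thr=0.7):
    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    weights, black_w = 0.0, 0.0
    for rank, (s, l) in enumerate(pairs, start=1):
        if s < thr:
            break
        w = 1.0 / rank
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0

def power_decay(scores, labels, p=8, thr=0.7):
    weights, black_w = 0.0, 0.0
    for s, l in zip(scores, labels):
        if s < thr:
            continue
        w = 1.0 / max(1 - s, 1e-6) ** p
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0

# Same neighbor ordering and labels, different similarity scales (hypothetical)
labels = [0, 1, 1, 1]
compressed = [0.75, 0.74, 0.73, 0.72]   # e.g. a poorly spread similarity range
spread = [0.99, 0.90, 0.80, 0.71]

for name, sims in [("compressed", compressed), ("spread", spread)]:
    print(name, round(rank_based(sims, labels), 3), round(power_decay(sims, labels), 3))
```

`rank_based` returns 13/25 = 0.52 for both neighborhoods. `power_decay` swings from about 0.63 on the compressed scale to nearly 0 on the spread one, purely because of the absolute similarity values.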
Step 5: Evaluation
Evaluate each combination on a held-out set with ground truth labels. AUC captures overall ranking quality; F1 under a recall constraint captures real-world operating points.
```python
import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_curve

def evaluate(scores, labels, recall_floor=0.5):
    auc = roc_auc_score(labels, scores)
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    # precision/recall have one more entry than thresholds; drop the last to align them
    p, r = precision[:-1], recall[:-1]
    f1 = 2 * p * r / (p + r + 1e-9)
    # Best F1 among operating points meeting the recall floor (fall back to all points)
    mask = r >= recall_floor
    best = int(np.argmax(np.where(mask, f1, -1.0))) if mask.any() else int(np.argmax(f1))
    return {
        "auc": round(float(auc), 4),
        "threshold": round(float(thresholds[best]), 4),
        "precision": round(float(p[best]), 4),
        "recall": round(float(r[best]), 4),
        "f1": round(float(f1[best]), 4),
    }
```
Run this for each (scoring function, thr, p/sigma/k) combination and sort by AUC. In these experiments the best configuration was consistent across evaluation dates. When it is not, the variation is usually driven by seed pool drift rather than scoring-function instability.
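The sweep itself is a small loop over a parameter grid. This minimal sketch assumes each held-out query is stored as a (scores, labels) neighborhood with a 0/1 ground-truth label. `grid_search`, `SEARCH_SPACE` and the toy data generator are hypothetical. Only two scorers are included for brevity, and results are ranked by AUC alone; swapping in `evaluate` would also record F1 at the recall floor.

```python
import itertools

import numpy as np
from sklearn.metrics import roc_auc_score

def majority_vote(scores, labels, thr=0.7):
    valid = [l for s, l in zip(scores, labels) if s >= thr]
    return sum(valid) / len(valid) if valid else 0.0

def power_decay(scores, labels, p=8, thr=0.7):
    weights, black_w = 0.0, 0.0
    for s, l in zip(scores, labels):
        if s < thr:
            continue
        w = 1.0 / max(1 - s, 1e-6) ** p
        weights += w
        black_w += w * l
    return black_w / weights if weights > 0 else 0.0

# Hypothetical search space: scorer name -> (function, {param: values to try})
SEARCH_SPACE = {
    "majority_vote": (majority_vote, {"thr": [0.0, 0.5, 0.7]}),
    "power_decay": (power_decay, {"thr": [0.0, 0.5, 0.7], "p": [2, 4, 8]}),
}

def grid_search(neighborhoods, gt_labels, search_space=SEARCH_SPACE):
    """neighborhoods: list of (scores, labels) pairs, one per held-out query."""
    results = []
    for name, (fn, grid) in search_space.items():
        keys = list(grid)
        for values in itertools.product(*(grid[k] for k in keys)):
            params = dict(zip(keys, values))
            preds = [fn(s, l, **params) for s, l in neighborhoods]
            results.append({"scorer": name, **params, "auc": roc_auc_score(gt_labels, preds)})
    return sorted(results, key=lambda r: r["auc"], reverse=True)

# Toy held-out set (hypothetical): closer neighbors share the query's label more often
rng = np.random.default_rng(0)
gt, hoods = [], []
for _ in range(200):
    risky = int(rng.random() < 0.3)
    sims = np.sort(rng.uniform(0.4, 1.0, size=20))[::-1]
    p_black = np.where(sims > 0.85, 0.8 if risky else 0.2, 0.5)
    gt.append(risky)
    hoods.append((sims, (rng.random(20) < p_black).astype(int)))

for row in grid_search(hoods, gt)[:3]:
    print(row)
```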
Step 6: Optional Calibration
Power decay scores are not well-calibrated probabilities — they tend to cluster near the extremes. If you need consistent thresholds across different retrieval keys or want to compare scores across systems, isotonic regression calibration is a low-effort improvement.
```python
from sklearn.isotonic import IsotonicRegression
ir = IsotonicRegression(out_of_bounds="clip")
ir.fit(raw_scores_train, gt_labels_train)
calibrated = ir.predict(raw_scores_test)
```
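This will not meaningfully change AUC. Isotonic regression is a monotone (non-decreasing) transform, so it preserves the ranking except where it collapses nearby scores into ties. What it does is make threshold selection more intuitive and stable.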
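Calibration should leave ranking quality essentially unchanged while making scores readable as risk rates. A held-out split is enough to check both. The sketch below uses toy data, an assumption rather than the data from these experiments. Raw scores pile up near 0 and 1, but the true risk rate only ranges from 0.2 to 0.8. The `toy_split` helper and the 0.6 cutoff are illustrative.

```python
import numpy as np
from sklearn.isotonic import IsotonicRegression
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)

def toy_split(n):
    """Toy data (assumption): raw scores near 0 and 1, milder true risk."""
    raw = rng.beta(0.3, 0.3, size=n)                     # U-shaped, like power-decay output
    y = (rng.random(n) < 0.2 + 0.6 * raw).astype(int)    # P(risky) only spans 0.2 to 0.8
    return raw, y

raw_train, y_train = toy_split(2000)
raw_test, y_test = toy_split(1000)

ir = IsotonicRegression(out_of_bounds="clip").fit(raw_train, y_train)
cal_test = ir.predict(raw_test)

print("AUC raw:       ", round(roc_auc_score(y_test, raw_test), 4))
print("AUC calibrated:", round(roc_auc_score(y_test, cal_test), 4))
# A cutoff now reads as an estimated risk rate, e.g. flag when P(risky) >= 0.6
flagged = cal_test >= 0.6
```

On the training split, the mean calibrated score equals the observed risk rate. This is what makes a cutoff like 0.6 interpretable.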
Summary
| What to check | Warning sign / recommended value | Action |
|---|---|---|
| Black/white seed ratio | > 10:1 | Fix seed pool before tuning |
| Embedding inter vs intra similarity | inter ≈ intra | Retrain embedding |
| Candidate neighborhood composition | > 95% black | Seed pool imbalance is the bottleneck |
| Similarity score distribution gap | < 0.05 | Retrieval signal too weak |
| Best filter threshold | 0.7 | Apply as default; try 0.75/0.80 for precision |
| Best scoring function | power_decay p=8 | Try topk+power_decay as alternative |
The pattern that comes up repeatedly: scoring function tuning has diminishing returns if the seed pool is imbalanced or the embedding is not separable. Those two things should be verified first. Given a healthy seed pool and separable embeddings, the jump from majority vote to power_decay_p8_thr0.7 is meaningful — roughly 20-25% F1 improvement in the experiments described here — and the combination search is worth running.