If you've ever sat there watching sklearn.cluster.KMeans churn through a large dataset while your laptop fan spins up like a jet engine, you're not alone. K-Means is one of those algorithms that feels like it should be fast — the concept is dead simple — but at scale, it eats memory and CPU time like nobody's business.
A new paper called Flash-KMeans just hit arXiv, and it's getting attention on Hacker News for good reason. It proposes an exact K-Means implementation that's dramatically faster and more memory-efficient than what most of us are using today. Not an approximation. Not a different algorithm. The same K-Means, just implemented smarter.
Let me break down why this matters and what you can actually do with it.
## Why Standard K-Means Is Wasteful
The classic Lloyd's algorithm for K-Means does three things every iteration:
- Compute distances from every point to every centroid
- Assign each point to its nearest centroid
- Recompute centroids as the mean of assigned points
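The loop itself fits in a few lines of NumPy. Here's a minimal sketch of one iteration (the function name `lloyd_step` is mine, just for illustration):

```python
import numpy as np

def lloyd_step(data, centroids):
    # 1. Distances from every point to every centroid: an (n, k) matrix
    dists = np.sum((data[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
    # 2. Assign each point to its nearest centroid
    labels = np.argmin(dists, axis=1)
    # 3. Recompute centroids as the mean of assigned points
    #    (keep the old centroid if a cluster ends up empty)
    new_centroids = np.array([
        data[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
        for j in range(len(centroids))
    ])
    return new_centroids, labels
```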
Step 1 is where the pain is. If you have n data points, k clusters, and d dimensions, you're computing n × k distances, each costing O(d) operations. That full distance matrix can get enormous.
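To make that concrete, here's a back-of-envelope calculation at the kind of sizes used in the example below (2M points, 1,000 clusters, 768 dimensions; my illustrative numbers, not figures from the paper):

```python
# Rough cost of one Lloyd iteration at illustrative sizes
n, k, d = 2_000_000, 1000, 768
flops = 2 * n * k * d           # ~2 ops per dimension (subtract, multiply-accumulate)
dist_bytes = n * k * 4          # full float32 distance matrix, if materialized
print(f"~{flops / 1e12:.0f} TFLOPs per iteration")
print(f"{dist_bytes / 1e9:.0f} GB just for the distance matrix")
```

That's roughly 3 TFLOPs of arithmetic and an 8 GB distance matrix, per iteration, before counting any other intermediates.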
Most implementations — including scikit-learn's — store intermediate results that balloon memory usage. When you're clustering millions of embeddings (say, 768-dimensional vectors from a transformer model), you can easily OOM before you get any results.
```python
# The naive approach most of us start with
from sklearn.cluster import KMeans
import numpy as np

# 2M vectors, 768 dimensions — good luck with memory
data = np.random.randn(2_000_000, 768).astype(np.float32)

kmeans = KMeans(n_clusters=1000, n_init=1)
kmeans.fit(data)  # hope you have 64GB+ of RAM
```
## What Flash-KMeans Does Differently
The core insight behind Flash-KMeans is applying the same kind of thinking that made Flash Attention successful: restructure the computation to be cache-friendly and avoid materializing large intermediate matrices.
Here's the gist of the key techniques:
- **Tiled distance computation:** Instead of computing the full `n × k` distance matrix at once, Flash-KMeans processes data in tiles that fit in CPU cache. This sounds simple, but the implementation details matter a lot — the tile sizes need to match your hardware's cache hierarchy.
- **Fused assignment and reduction:** Rather than storing all distances and then finding minimums in a separate pass, Flash-KMeans fuses these operations: compute a tile of distances, immediately update the assignments and running centroid sums, then discard the distances. Memory usage drops from `O(n × k)` to `O(tile_size × k)`.
- **Exploiting the triangle inequality:** By tracking upper and lower bounds on distances, many point-centroid distance calculations can be skipped entirely. This isn't new (Elkan's algorithm does this), but combining it with the tiled approach is where the novelty lies.
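Here's a rough sketch of what a fused tile pass might look like, based on my reading of the idea (the function name and details are mine, not the paper's). It uses the expansion `||x − c||² = ||x||² − 2x·c + ||c||²` and drops the per-point `||x||²` term, which doesn't affect the argmin:

```python
import numpy as np

def fused_lloyd_iteration(data, centroids, tile_size=4096):
    """One Lloyd iteration without ever materializing the full (n, k) matrix."""
    k, d = centroids.shape
    sums = np.zeros((k, d), dtype=np.float64)   # running per-cluster sums
    counts = np.zeros(k, dtype=np.int64)
    c_norms = np.sum(centroids ** 2, axis=1)    # ||c||^2, reused for every tile
    for start in range(0, len(data), tile_size):
        tile = data[start:start + tile_size]
        # Partial squared distances for this tile only: (tile_size, k)
        dists = c_norms - 2.0 * tile @ centroids.T
        labels = np.argmin(dists, axis=1)
        # Fused reduction: fold the tile into centroid sums, then discard dists
        np.add.at(sums, labels, tile)
        counts += np.bincount(labels, minlength=k)
    nonempty = counts > 0
    new_centroids = centroids.copy()
    new_centroids[nonempty] = (sums[nonempty] / counts[nonempty, None]).astype(centroids.dtype)
    return new_centroids, counts
```

Peak memory here is one `(tile_size, k)` buffer instead of `(n, k)`, which is the whole point.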
The result? The paper reports speedups of 5-16x over optimized baselines while using a fraction of the memory. And again — these are exact results. Same output as standard K-Means, just faster.
## Why This Matters Right Now
We're in the middle of an embeddings explosion. Every app and its dog is generating vector embeddings for search, RAG pipelines, recommendations, you name it. Clustering those embeddings is a bread-and-butter operation for:
- Building approximate nearest neighbor indices (IVF-style)
- Analyzing user behavior patterns
- Product quantization for vector compression
- Deduplication and data cleaning
FAISS already uses optimized K-Means internally for building IVF indices, but even FAISS can struggle when your dataset gets into the hundreds of millions. Flash-KMeans-style optimizations could make a real difference there.
```python
# What clustering large embedding datasets often looks like in practice
import faiss
import numpy as np

d = 768
ncentroids = 4096
data = np.load("embeddings.npy")  # shape: (50_000_000, 768)

# FAISS K-Means is already faster than sklearn, but still hungry
kmeans = faiss.Kmeans(d, ncentroids, niter=20, gpu=False)
kmeans.train(data)  # still takes a while on CPU

# Flash-KMeans-style tiling could help here
# The key idea: process in cache-friendly blocks
TILE_SIZE = 4096  # fits in L2 cache
for i in range(0, len(data), TILE_SIZE):
    tile = data[i:i + TILE_SIZE]
    # compute distances only for this tile
    # update assignments immediately
    # accumulate centroid sums without storing full distance matrix
```
## The Broader Pattern: Algorithm-Level Optimization
What I find most interesting about Flash-KMeans isn't the specific technique — it's the pattern. We're seeing a wave of papers that take well-known algorithms and dramatically speed them up by respecting modern hardware realities.
Flash Attention did it for transformers. Flash-KMeans does it for clustering. The common thread is: stop treating RAM as infinite and flat. Cache hierarchy matters. Memory bandwidth is the bottleneck, not FLOPs.
This is a mindset shift that matters for application developers too. When I'm building data pipelines, I've started thinking more carefully about memory access patterns — even in Python. Things like:
- Processing data in chunks that fit in cache
- Avoiding unnecessary copies and intermediate arrays
- Using `float32` instead of `float64` when precision isn't critical
- Picking libraries that are cache-aware under the hood
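The dtype point alone is an easy 2x win on memory and bandwidth:

```python
import numpy as np

a64 = np.random.randn(100_000, 128)     # float64 by default
a32 = a64.astype(np.float32)            # same data, half the footprint
print(a64.nbytes // 2**20, "MiB vs", a32.nbytes // 2**20, "MiB")
```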
```python
# Simple example: chunked processing vs. materializing everything
import numpy as np

def compute_distances_naive(data, centroids):
    """Materializes full distance matrix — memory hog."""
    # This creates an (n, k, d) intermediate all at once
    diff = data[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    return np.sum(diff ** 2, axis=2)  # O(n * k * d) memory for diff

def compute_assignments_tiled(data, centroids, tile_size=2048):
    """Tiled approach — constant memory overhead."""
    n = len(data)
    assignments = np.empty(n, dtype=np.int32)
    for start in range(0, n, tile_size):
        end = min(start + tile_size, n)
        tile = data[start:end]
        # Only materialize distances for this tile
        dists = np.sum((tile[:, np.newaxis, :] - centroids[np.newaxis, :, :]) ** 2, axis=2)
        assignments[start:end] = np.argmin(dists, axis=1)
    return assignments
```
The tiled version uses way less peak memory and is often faster too because of better cache utilization. It's not as dramatic as the full Flash-KMeans approach (which also fuses operations and uses bound-skipping), but it captures the core idea.
## Should You Care?
If you're clustering datasets under 100K points, honestly, sklearn is fine. Don't optimize what isn't slow.
But if you're working with large-scale embeddings — especially if you're building search infrastructure, doing vector quantization, or running clustering as part of a larger pipeline — keep an eye on Flash-KMeans. I expect we'll see implementations show up in libraries like FAISS and scikit-learn-extra before long.
In the meantime, the tiling approach is something you can apply yourself today. I've been using chunked processing for a while when building analytics pipelines — even when tracking simple metrics with tools like Umami or Plausible for privacy-friendly web analytics, the principle of processing data in efficient chunks rather than loading everything into memory applies everywhere.
The bigger takeaway: the era of "just throw more RAM at it" is ending. Papers like Flash-KMeans show there's still a lot of performance left on the table in algorithms we thought were fully optimized. The hardware-aware algorithm design pattern is going to keep producing wins across the board.
I'll be watching for a reference implementation to drop. When it does, I'll benchmark it against FAISS K-Means on some real embedding datasets and share the results.