K-Means Clustering: Complete Guide
K-Means from first principles — the algorithm and convergence proof, K-Means++ initialization, choosing K with elbow/silhouette/gap statistics, Mini-Batch variants, distributed K-Means, when the algorithm breaks down, and a worked large-scale interview scenario.
The Core Algorithm
K-Means partitions n data points into k non-overlapping clusters by minimizing the within-cluster sum of squares (WCSS): J = Σⱼ Σₓ∈Cⱼ ||x − μⱼ||². The algorithm solves this via coordinate descent: hold centroids fixed → optimize assignments; hold assignments fixed → optimize centroids. Each step is analytically solvable (nearest centroid; centroid = mean), making the algorithm simple and fast. The catch: minimizing J globally is NP-hard, so this iterative procedure only finds a local minimum. Convergence is always guaranteed; solution quality depends on initialization.
K-Means Algorithm — Complete Flowchart
Convergence Proof — Why K-Means Always Terminates
Claim: WCSS is non-increasing
Let J(t) = WCSS after iteration t. We prove J(t+1) ≤ J(t) by showing each sub-step (assign then update) cannot increase J.
Assignment step doesn't increase J
Each point is reassigned to its nearest centroid. If the nearest centroid changes, the squared distance to the new centroid ≤ squared distance to the old centroid → J decreases or stays the same. Formally: ||xᵢ − μ*||² ≤ ||xᵢ − μold||² where μ* is the new assignment.
Update step doesn't increase J
The centroid μⱼ = mean(Cⱼ) is the unique minimizer of Σₓ∈Cⱼ ||x − c||² over all c. Proof: ∂/∂c Σ||x−c||² = −2Σ(x−c) = 0 → c = mean. Any other centroid placement gives higher WCSS.
Finite termination
Since J is non-increasing and bounded below by 0, and there are finitely many distinct assignment configurations (kⁿ), the algorithm must terminate. In practice, tol-based early stopping triggers much sooner — typically after 10–50 iterations.
Local minimum caveat
The algorithm finds a stationary point of J, not necessarily the global minimum. Two bad initialization patterns: (1) Multiple centroids in one true cluster → one cluster splits, another remains undetected. (2) A centroid starts between two clusters → captures both → neither is well-represented. Solution: run K-Means multiple times (n_init=10–20) or use K-Means++ initialization.
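The local-minimum caveat is easy to see empirically. A minimal sketch, assuming scikit-learn is available (the blob parameters are purely illustrative):

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Illustrative data: 5 well-separated blobs
X, _ = make_blobs(n_samples=2000, centers=5, cluster_std=1.0, random_state=0)

# Single random-init runs land in different local minima...
inertias = [
    KMeans(n_clusters=5, init="random", n_init=1, random_state=s).fit(X).inertia_
    for s in range(10)
]
print("single-run inertia spread:", round(min(inertias)), "to", round(max(inertias)))

# ...while restarts keep only the best of n_init runs
best = KMeans(n_clusters=5, init="k-means++", n_init=10, random_state=0).fit(X)
print("k-means++ with n_init=10:", round(best.inertia_))
```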
K-Means from Scratch — Fully Vectorized
```python
import numpy as np

# ── Core assignment step ────────────────────────────────────────────────
def assign_clusters(X, centroids):
    """
    Vectorized: computes the (n, k) distance matrix in one shot.
    X: (n, d), centroids: (k, d) → labels: (n,)
    """
    # Broadcasting: (n, 1, d) - (1, k, d) → (n, k, d) → squared (n, k)
    diffs = X[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    sq_dists = (diffs ** 2).sum(axis=2)  # (n, k)
    return np.argmin(sq_dists, axis=1)   # (n,)

# ── Core update step ───────────────────────────────────────────────────
def update_centroids(X, labels, k, old_centroids):
    """Recompute centroids; keep the old centroid for an empty cluster."""
    new = np.empty_like(old_centroids)
    for j in range(k):
        mask = labels == j
        new[j] = X[mask].mean(axis=0) if mask.any() else old_centroids[j]
    return new

# ── K-Means++ initialization ───────────────────────────────────────────
def kmeans_pp_init(X, k, rng=None):
    """
    O(nkd) initialization with an expected O(log k) approximation guarantee.
    Strategy: spread centroids apart using D² sampling.
    """
    rng = np.random.default_rng(42) if rng is None else rng
    n = len(X)
    centroids = [X[rng.integers(n)]]  # first centroid: uniform random
    for _ in range(k - 1):
        # D²(x) = min squared distance to any already-chosen centroid
        dists = np.min(
            [np.sum((X - c) ** 2, axis=1) for c in centroids], axis=0
        )
        probs = dists / dists.sum()  # normalize to a probability distribution
        # Sample proportional to D² — far-away points are more likely chosen
        centroids.append(X[rng.choice(n, p=probs)])
    return np.array(centroids, dtype=float)

# ── Full K-Means ───────────────────────────────────────────────────────
class KMeans:
    def __init__(self, k=3, max_iters=300, tol=1e-4, init="kmeans++", n_init=10):
        self.k = k
        self.max_iters = max_iters
        self.tol = tol
        self.init = init
        self.n_init = n_init  # number of random restarts

    def _run_once(self, X, rng):
        centroids = (kmeans_pp_init(X, self.k, rng) if self.init == "kmeans++"
                     else X[rng.choice(len(X), self.k, replace=False)].astype(float))
        for _ in range(self.max_iters):
            labels = assign_clusters(X, centroids)
            new_centroids = update_centroids(X, labels, self.k, centroids)
            if np.linalg.norm(new_centroids - centroids) < self.tol:
                break
            centroids = new_centroids
        labels = assign_clusters(X, centroids)  # ensure labels match final centroids
        inertia = sum(
            np.sum((X[labels == j] - centroids[j]) ** 2)
            for j in range(self.k)
        )
        return labels, centroids, inertia

    def fit(self, X):
        X = np.array(X, dtype=float)
        best_inertia = np.inf
        rng = np.random.default_rng(42)
        for _ in range(self.n_init):
            labels, centroids, inertia = self._run_once(X, rng)
            if inertia < best_inertia:
                best_inertia = inertia
                self.labels_ = labels
                self.cluster_centers_ = centroids
                self.inertia_ = inertia
        return self
```
K-Means++ — Why D² Sampling Works
Standard random initialization fails because centroids can cluster together, leaving some true clusters with no nearby centroid. K-Means++ fixes this by making 'far from existing centroids' mean 'more likely to be chosen'. The key insight: if we sample with probability proportional to D(x)² (squared distance to nearest centroid), we're essentially spreading centroids across the high-density regions of the data.
Theoretical guarantee: Arthur & Vassilvitskii (2007) proved that K-Means++ gives expected WCSS within O(log k) of the optimal clustering. In practice, this means 2–5× better solutions and 2–3× fewer iterations than random initialization on typical datasets.
Cost: K-Means++ initialization takes O(nkd) — a constant factor overhead. For k=10 clusters this is 10 additional distance computations per sample, almost always worth it.
K-Means++ vs Random Init — Why Spread Matters
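A quick empirical check of the "better solutions, fewer iterations" claim, again assuming scikit-learn (data and run counts are illustrative):

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=5000, centers=10, random_state=1)

for init in ("random", "k-means++"):
    runs = [KMeans(n_clusters=10, init=init, n_init=1, random_state=s).fit(X)
            for s in range(10)]
    print(f"{init:>10}:",
          "mean inertia =", round(np.mean([r.inertia_ for r in runs])),
          "| mean iterations =", round(np.mean([r.n_iter_ for r in runs]), 1))
```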
Choosing K — Three Methods + When to Use Each
Elbow Method (Inertia vs K)
Run K-Means for k=1..15. Plot WCSS against k. The curve bends at the 'elbow' — where adding more clusters gives diminishing improvement. Compute numerically: find k where the second derivative of inertia is maximized. Limitation: the elbow is often ambiguous on real data (smooth curve, no clear bend). Best for: exploratory analysis, verifying that your intuitive k makes sense.
Silhouette Score
For each point xᵢ: a(i) = mean distance to other points in the same cluster (cohesion); b(i) = mean distance to points in the nearest other cluster (separation). Silhouette(i) = (b(i) − a(i)) / max(a(i), b(i)). Ranges from −1 (wrong cluster) to +1 (perfect). The average silhouette score peaks at the optimal k and is more objective than the elbow. Limitation: O(n²) computation — compute it on a random subsample for large datasets. A from-scratch sketch follows.
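A minimal from-scratch version of the formula above — naive O(n²), for illustration only; the function name is mine and it assumes at least two clusters:

```python
import numpy as np

def silhouette_values(X, labels):
    """Naive O(n²) silhouette: returns s(i) for every point."""
    # Full pairwise Euclidean distance matrix
    D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)
    clusters = np.unique(labels)
    s = np.zeros(len(X))
    for i in range(len(X)):
        own = labels[i]
        in_own = labels == own
        if in_own.sum() == 1:
            continue  # singleton cluster: s(i) = 0 by convention
        a = D[i, in_own].sum() / (in_own.sum() - 1)  # exclude the point itself
        b = min(D[i, labels == c].mean() for c in clusters if c != own)
        s[i] = (b - a) / max(a, b)
    return s

# silhouette_values(X, labels).mean() should match sklearn's silhouette_score
```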
Gap Statistic (Tibshirani et al.)
Gap(k) = E[log(WCSS_random(k))] − log(WCSS_data(k)). Compares your clustering to K-Means on a uniform random reference distribution. The optimal k is the smallest k where Gap(k) ≥ Gap(k+1) − s(k+1), with s the reference standard deviation corrected for simulation error. Most statistically principled, but it requires clustering B reference datasets (B=20 is typical), making it roughly B× more expensive. Best for: academic papers, or when you need a defensible statistical justification for k. A compact implementation follows.
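Nothing else in this guide implements the gap statistic, so here is a compact sketch under the usual uniform bounding-box reference from the paper (helper names and default parameters are mine):

```python
import numpy as np
from sklearn.cluster import KMeans

def gap_statistic(X, k_max=10, B=20, random_state=42):
    """Return the smallest k with Gap(k) >= Gap(k+1) - s(k+1)."""
    rng = np.random.default_rng(random_state)
    lo, hi = X.min(axis=0), X.max(axis=0)  # bounding box of the data

    def log_wcss(data, k):
        km = KMeans(n_clusters=k, n_init=5, random_state=random_state).fit(data)
        return np.log(km.inertia_)

    gaps, stds = [], []
    for k in range(1, k_max + 1):
        # B uniform reference datasets drawn inside the bounding box
        ref = np.array([log_wcss(rng.uniform(lo, hi, size=X.shape), k)
                        for _ in range(B)])
        gaps.append(ref.mean() - log_wcss(X, k))
        stds.append(ref.std() * np.sqrt(1 + 1 / B))  # simulation-error correction
    for k in range(1, k_max):
        if gaps[k - 1] >= gaps[k] - stds[k]:
            return k
    return k_max
```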
Business / Domain Knowledge (Often Best)
If clustering customers: you want 5–8 segments that product teams can act on. If clustering documents: topic models suggest 20–100 topics. If clustering for dimensionality reduction: pick k based on explained variance. Statistical methods validate k; they rarely discover it from scratch in meaningful applications.
Automatic K Selection
```python
import numpy as np
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.metrics import silhouette_score

def find_optimal_k(X, k_max=15, sample_size=5000):
    """
    Evaluate k=2..k_max using elbow + silhouette.
    Uses MiniBatchKMeans for speed on large datasets.
    """
    n = len(X)
    Cls = MiniBatchKMeans if n > 50_000 else KMeans
    inertias, sil_scores = [], []
    sample_idx = np.random.choice(n, min(sample_size, n), replace=False)
    X_sample = X[sample_idx]
    for k in range(2, k_max + 1):
        km = Cls(n_clusters=k, init="k-means++", n_init=5, random_state=42)
        km.fit(X)
        inertias.append(km.inertia_)
        # Silhouette on a subsample (O(n²) → O(sample²))
        sil_labels = km.predict(X_sample)
        if len(set(sil_labels)) > 1:  # silhouette needs at least 2 clusters
            sil_scores.append(silhouette_score(X_sample, sil_labels))
        else:
            sil_scores.append(-1)
    # Elbow: the k maximizing the second derivative of inertia
    deltas = np.diff(inertias)
    second_deriv = np.diff(deltas)
    elbow_k = 3 + int(np.argmax(second_deriv)) if len(second_deriv) > 0 else 2
    # Silhouette: pick k with the highest average score
    sil_k = 2 + int(np.argmax(sil_scores))
    best_sil = max(sil_scores)
    return {
        "elbow_k": elbow_k,
        "silhouette_k": sil_k,
        "best_silhouette": round(best_sil, 3),
        "inertias": inertias,
        "silhouette_scores": sil_scores,
        "recommendation": sil_k if best_sil > 0.5 else elbow_k,
    }
```
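A usage sketch (the dataset is illustrative):

```python
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=10_000, centers=6, random_state=0)
result = find_optimal_k(X, k_max=12)
print(result["elbow_k"], result["silhouette_k"], result["recommendation"])
```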
K-Means Variants — When to Use Each
| Variant | Key Change | Pros | Cons | When to Use |
|---|---|---|---|---|
| K-Means | Standard EM | Simple, exact | O(nkd) per iter, sensitive to init | n < 100K |
| K-Means++ | D² initialization | Better local optima, O(log k) guarantee | Slight init overhead | Default over random init always |
| Mini-Batch K-Means | Random batch per iter, online update | O(1) memory per iter, 3–10× faster | Slightly worse WCSS (~1%) | n > 100K or streaming data (usage sketch below) |
| K-Medoids (PAM) | Centers must be actual data points | Robust to outliers (uses L1 implicitly) | O(kn²) — very slow | Small data with meaningful outliers |
| K-Means‖ (parallel++) | Multiple candidates per round, fewer passes | Distributed-friendly init | More complex | Spark / MapReduce K-Means |
| Bisecting K-Means | Start with k=1, repeatedly split largest | More stable than random, fewer restarts needed | Not optimizing global J | Hierarchical exploration needed |
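For the Mini-Batch row above, a minimal usage sketch assuming scikit-learn (the batch size and data are illustrative):

```python
from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=200_000, centers=20, n_features=32, random_state=0)

# Each step processes one random batch instead of the full dataset,
# trading ~1% WCSS for a large speedup at this scale.
mbk = MiniBatchKMeans(n_clusters=20, batch_size=4096, n_init=3, random_state=0)
labels = mbk.fit_predict(X)
print(mbk.inertia_, mbk.n_iter_)
```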
When K-Means Fails — Know the 4 Assumptions
K-Means makes four assumptions; violating any of them degrades the clustering (a failure demo follows the list):
- **Spherical clusters** — fails on elongated, banana-shaped, or ring-shaped clusters. DBSCAN or Spectral Clustering handle these.
- **Equal variance** — fails when clusters have very different spreads. GMM with unconstrained covariances handles this.
- **Equal size** — a large cluster may split while a small one disappears. Enforce a minimum cluster size or use K-Medoids.
- **Meaningful Euclidean distance** — fails for high-dimensional sparse data (cosine similarity works better), categorical features (use K-Modes or Gower distance), and mixed types (normalize, then apply an appropriate distance).

Practical check: after clustering, visualize with t-SNE/UMAP. If clusters overlap significantly in the visualization, K-Means is probably wrong for your data.
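The spherical-cluster assumption is the easiest to demonstrate. A sketch with scikit-learn's make_circles (parameters are illustrative):

```python
from sklearn.cluster import DBSCAN, KMeans
from sklearn.datasets import make_circles
from sklearn.metrics import adjusted_rand_score

# Two concentric rings: non-spherical, a classic K-Means failure case
X, y = make_circles(n_samples=1000, factor=0.4, noise=0.05, random_state=0)

km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
db = DBSCAN(eps=0.2, min_samples=5).fit(X)

# ARI = 1 means a perfect match to the true rings; 0 is chance level
print("K-Means ARI:", adjusted_rand_score(y, km.labels_))  # typically near 0
print("DBSCAN  ARI:", adjusted_rand_score(y, db.labels_))  # typically near 1
```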
K-Means vs DBSCAN vs GMM — Decision Guide
| Algorithm | Cluster Shape | Outliers | Soft Assignments | Scalability | Best For |
|---|---|---|---|---|---|
| K-Means | Spherical | Sensitive (assigns to nearest) | No | O(nkd) — excellent | Large structured data, known k |
| DBSCAN | Arbitrary density-connected | Labeled as noise class | No | O(n log n) with kd-tree | Geographic data, arbitrary shapes, unknown k |
| GMM | Elliptical Gaussian | Moderate (tails exist) | Yes (probabilities) | O(nkd²) — slower | When you need P(cluster|x), confidence scores |
| Agglomerative | Arbitrary (dendrogram) | Moderate | No | O(n² log n) — slow | Small data, exploring hierarchical structure |
| Spectral | Non-convex (graph-based) | Sensitive | No | O(n³) — very slow | Ring/manifold clusters, small to medium n |
Distributed K-Means (MapReduce Pattern)
```python
import numpy as np

def mapreduce_kmeans_iteration(data_shards, centroids, k):
    """
    One MapReduce iteration of K-Means.

    Map phase:    each shard independently assigns points and
                  computes per-cluster partial sums.
    Reduce phase: aggregate partial sums → new centroids.

    Key guarantee: produces IDENTICAL results to sequential K-Means
    because addition is associative and commutative.
    """
    # MAP: each shard independently
    def map_shard(X_shard, centroids):
        labels = np.argmin(
            np.sum((X_shard[:, None] - centroids[None]) ** 2, axis=2),
            axis=1,
        )
        # Partial statistics: (sum_of_points, count) per cluster
        partial_sums = np.zeros((k, centroids.shape[1]))
        counts = np.zeros(k, dtype=int)
        for j in range(k):
            mask = labels == j
            if mask.any():
                partial_sums[j] = X_shard[mask].sum(axis=0)
                counts[j] = mask.sum()
        return partial_sums, counts

    # Simulate the parallel map
    all_partial_sums = np.zeros((k, centroids.shape[1]))
    all_counts = np.zeros(k, dtype=int)
    for shard in data_shards:
        partial_sums, counts = map_shard(shard, centroids)
        all_partial_sums += partial_sums  # REDUCE: just sum
        all_counts += counts

    # Final centroid = sum / count (same as the mean over all points);
    # guard against division by zero for empty clusters.
    safe_counts = np.maximum(all_counts, 1)[:, None]
    new_centroids = np.where(
        all_counts[:, None] > 0,
        all_partial_sums / safe_counts,
        centroids,  # keep the old centroid if a cluster is empty
    )
    return new_centroids

# Notes on distributed implementation:
# - K-Means‖ (parallel++) is the distributed init: in O(log n) rounds,
#   sample O(k) candidate centroids, then reduce to k via weighted K-Means.
# - Mini-Batch K-Means is an online approximation that's trivially distributed.
# - The map phase is embarrassingly parallel — no communication needed.
# - Only the reduce (centroid aggregation) requires cross-node communication.
# - This makes K-Means one of the most MapReduce-friendly ML algorithms.
```
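A sketch of the driver loop that would sit on top of mapreduce_kmeans_iteration, with worker shards simulated as array splits (all sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100_000, 16))
k = 8

# Simulate 10 worker shards; on a real cluster these live on separate nodes
shards = np.array_split(X, 10)

centroids = X[rng.choice(len(X), k, replace=False)]  # simple random init
for it in range(100):
    new_centroids = mapreduce_kmeans_iteration(shards, centroids, k)
    shift = np.linalg.norm(new_centroids - centroids)
    centroids = new_centroids
    if shift < 1e-4:  # same tol-based stopping as the sequential version
        print(f"converged after {it + 1} iterations")
        break
```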
Interview Scenario: Cluster 10M User Profiles
Setup: 10M users, 500 features (behavior + demographics), goal = 20 segments for personalization.
Step 1 — Dimensionality reduction: PCA to 50 components (retain 95% variance) or autoencoder for non-linear structure. Reduces computation by 10×.
Step 2 — Algorithm choice: Mini-Batch K-Means, batch_size=10000. Full K-Means needs O(10M × 20 × 50) ≈ 10B FLOPs per iteration — acceptable on GPU, but the reduced 10M × 50 float64 matrix alone is 10M × 50 × 8 bytes = 4 GB, so full-batch passes are memory-heavy.
Step 3 — Initialization: K-Means‖ (parallel++) on a 100K sample. Then warm-start Mini-Batch with these centroids.
Step 4 — Choosing k: silhouette score on 50K holdout sample for k=10, 15, 20, 25, 30. Also evaluate downstream metric (CTR, conversion) for top 3 candidates.
Step 5 — Stability: re-cluster monthly. Track cluster drift using centroid shift plus overlap of nearest-neighbor membership. Map old clusters to new via the Hungarian algorithm on the centroid distance matrix (sketch below).
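For the Step 5 mapping, a sketch using SciPy's linear_sum_assignment (the Hungarian algorithm) on the centroid distance matrix; the function name is mine:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

def map_clusters(old_centroids, new_centroids):
    """Match each old cluster to a new cluster, one-to-one."""
    cost = cdist(old_centroids, new_centroids)      # (k, k) distance matrix
    old_idx, new_idx = linear_sum_assignment(cost)  # minimizes total distance
    return dict(zip(old_idx.tolist(), new_idx.tolist()))

# Hypothetical usage: mapping[3] == 7 means last month's segment 3
# corresponds to this month's segment 7.
# mapping = map_clusters(old_model.cluster_centers_, new_model.cluster_centers_)
```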