
K-Means Clustering: Complete Guide

K-Means from first principles — the algorithm and convergence proof, K-Means++ initialization, choosing K with elbow/silhouette/gap statistics, Mini-Batch variants, distributed K-Means, and when the algorithm breaks down. 10 hard interview questions with detailed answers.


The Core Algorithm

K-Means partitions n data points into k non-overlapping clusters by minimizing the within-cluster sum of squares (WCSS): J = Σⱼ Σₓ∈Cⱼ ||x − μⱼ||². The standard procedure (Lloyd's algorithm) minimizes J by coordinate descent: hold centroids fixed → optimize assignments; hold assignments fixed → optimize centroids. Each step has a closed-form solution (nearest centroid; centroid = mean), which makes the algorithm simple and fast. The catch: minimizing J globally is NP-hard, so no efficient method is known that guarantees the true optimum, and Lloyd's iterations only reach a local minimum. Convergence is guaranteed; the quality of the result depends on initialization.
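To make the objective concrete, the following is a minimal sketch that evaluates J for two hand-picked assignments of a toy dataset. The helper names `wcss` and `cluster_means` and the data are illustrative assumptions, not part of the guide's own code.

```python
import numpy as np

def wcss(X, labels, centroids):
    """J = sum over clusters of squared distances from each point to its centroid."""
    return float(np.sum((X - centroids[labels]) ** 2))

def cluster_means(X, labels, k):
    """Centroid of each cluster (assumes every cluster is non-empty)."""
    return np.array([X[labels == j].mean(axis=0) for j in range(k)])

# Toy data (illustrative): two tight pairs on a line.
X = np.array([[0.0], [2.0], [10.0], [12.0]])

labels_a = np.array([0, 0, 1, 1])  # groups neighbours together
labels_b = np.array([0, 1, 0, 1])  # mixes the two pairs

j_a = wcss(X, labels_a, cluster_means(X, labels_a, 2))
j_b = wcss(X, labels_b, cluster_means(X, labels_b, 2))
print(j_a, j_b)  # 4.0 100.0
```

The grouping that keeps neighbours together has a far lower J, which is exactly what the assignment and update steps push toward.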

Figure: K-Means Algorithm, Complete Flowchart (diagram not rendered)

Convergence Proof — Why K-Means Always Terminates

01

Claim: WCSS is non-increasing

Let J(t) = WCSS after iteration t. We prove J(t+1) ≤ J(t) by showing each sub-step (assign then update) cannot increase J.

02

Assignment step doesn't increase J

Each point is reassigned to its nearest centroid. If the nearest centroid changes, the squared distance to the new centroid ≤ squared distance to the old centroid → J decreases or stays the same. Formally: ||xᵢ − μ*||² ≤ ||xᵢ − μold||² where μ* is the new assignment.

03

Update step doesn't increase J

The centroid μⱼ = mean(Cⱼ) is the unique minimizer of Σₓ∈Cⱼ ||x − c||² over all c. Proof: ∂/∂c Σ||x−c||² = −2Σ(x−c) = 0 → c = mean. Any other centroid placement gives higher WCSS.

04

Finite termination

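J is non-increasing and bounded below by 0. There are only finitely many assignment configurations (at most kⁿ), and with a consistent tie-breaking rule J strictly decreases whenever the assignment changes, so no configuration can repeat and the algorithm must terminate. In practice, tol-based early stopping triggers much sooner, typically after 10–50 iterations.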

05

Local minimum caveat

The algorithm finds a stationary point of J, not necessarily the global minimum. Two bad initialization patterns: (1) Multiple centroids in one true cluster → one cluster splits, another remains undetected. (2) A centroid starts between two clusters → captures both → neither is well-represented. Solution: run K-Means multiple times (n_init=10–20) or use K-Means++ initialization.
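The monotonicity argument can be checked empirically. The sketch below records J after every assignment step and every update step and verifies that the sequence never increases. The function name `lloyd_trace`, the synthetic blobs, and the random seed are assumptions made for illustration.

```python
import numpy as np

def lloyd_trace(X, centroids, n_iters=20):
    """Run Lloyd iterations, recording J after each assignment and each update."""
    centroids = centroids.astype(float).copy()
    trace = []
    for _ in range(n_iters):
        # Assignment step: nearest centroid (argmin breaks ties by lowest index).
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        trace.append(float(d2[np.arange(len(X)), labels].sum()))
        # Update step: move each non-empty cluster's centroid to its mean.
        for j in range(len(centroids)):
            if np.any(labels == j):
                centroids[j] = X[labels == j].mean(axis=0)
        trace.append(float(((X - centroids[labels]) ** 2).sum()))
    return np.array(trace)

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(loc=c, scale=1.0, size=(100, 2)) for c in (0.0, 5.0, 10.0)])
init = X[rng.choice(len(X), size=3, replace=False)]

trace = lloyd_trace(X, init)
# Allow a tiny tolerance for floating-point rounding between the two J formulas.
print(bool(np.all(np.diff(trace) <= 1e-9 * trace[0])))
```

Once the trace flattens out, the assignments have stopped changing, which is the termination condition used in the proof.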

K-Means from Scratch — Fully Vectorized

kmeans_scratch.py (Python)
import numpy as np

# ── Core assignment step ────────────────────────────────────────────────

def assign_clusters(X, centroids):
    """
    Vectorized: computes (n, k) distance matrix in one shot.
    X: (n, d), centroids: (k, d) → labels: (n,)
    """
    # Broadcasting: (n, 1, d) - (1, k, d) → (n, k, d) → squared (n, k)
    diffs = X[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    sq_dists = (diffs ** 2).sum(axis=2)  # (n, k)
    return np.argmin(sq_dists, axis=1)    # (n,)

# ── Core update step ───────────────────────────────────────────────────

def update_centroids(X, labels, k, old_centroids):
    """Recompute centroids; keep old centroid on empty cluster."""
    new = np.empty_like(old_centroids)
    for j in range(k):
        mask = labels == j
        new[j] = X[mask].mean(axis=0) if mask.any() else old_centroids[j]
    return new

# ── K-Means++ initialization ───────────────────────────────────────────

def kmeans_pp_init(X, k, rng=None):
    """
    K-Means++ seeding via D² sampling: O(nkd) total, with expected WCSS
    within an O(log k) factor of optimal (Arthur & Vassilvitskii, 2007).
    """
    rng = rng or np.random.default_rng(42)
    n = len(X)
    centroids = [X[rng.integers(n)]]  # first centroid: uniform at random
    # Running D²(x): squared distance to the nearest centroid chosen so far.
    # Caching it keeps each round at O(nd) instead of revisiting every centroid.
    dists = np.sum((X - centroids[0]) ** 2, axis=1)

    for _ in range(k - 1):
        total = dists.sum()
        # If every point coincides with a chosen centroid, fall back to uniform
        probs = dists / total if total > 0 else np.full(n, 1.0 / n)
        # Sample proportional to D²: far-away points are more likely chosen
        centroids.append(X[rng.choice(n, p=probs)])
        dists = np.minimum(dists, np.sum((X - centroids[-1]) ** 2, axis=1))

    return np.array(centroids, dtype=float)

# ── Full K-Means ───────────────────────────────────────────────────────

class KMeans:
    def __init__(self, k=3, max_iters=300, tol=1e-4, init="kmeans++", n_init=10):
        self.k = k
        self.max_iters = max_iters
        self.tol = tol
        self.init = init
        self.n_init = n_init  # number of random restarts

    def _run_once(self, X, rng):
        centroids = (kmeans_pp_init(X, self.k, rng) if self.init == "kmeans++"
                     else X[rng.choice(len(X), self.k, replace=False)].astype(float))

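        for _ in range(self.max_iters):
            labels = assign_clusters(X, centroids)
            new_centroids = update_centroids(X, labels, self.k, centroids)
            shift = np.linalg.norm(new_centroids - centroids)
            centroids = new_centroids  # always keep the latest means
            if shift < self.tol:
                break

        # Recompute labels against the final centroids so that labels,
        # centroids, and inertia all describe the same solution.
        labels = assign_clusters(X, centroids)
        inertia = float(np.sum((X - centroids[labels]) ** 2))
        return labels, centroids, inertia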

    def fit(self, X):
        X = np.array(X, dtype=float)
        best_inertia = np.inf
        rng = np.random.default_rng(42)

        for _ in range(self.n_init):
            labels, centroids, inertia = self._run_once(X, rng)
            if inertia < best_inertia:
                best_inertia = inertia
                self.labels_ = labels
                self.cluster_centers_ = centroids
                self.inertia_ = inertia

        return self

K-Means++ — Why D² Sampling Works

Standard random initialization fails because centroids can land close together, leaving some true clusters with no nearby centroid. K-Means++ fixes this by making points that are far from the existing centroids more likely to be chosen. The key insight: sampling with probability proportional to D(x)² (squared distance to the nearest chosen centroid) spreads the initial centroids across the data, while the randomness keeps a single outlier from being picked every time, as a deterministic farthest-point rule would do.

Theoretical guarantee: Arthur & Vassilvitskii (2007) proved that K-Means++ initialization gives an expected WCSS within a factor of O(log k) of the optimal clustering. In practice it usually reaches a lower final WCSS in fewer iterations than random initialization, although the size of the gain depends on the dataset.

Cost: K-Means++ initialization takes O(nkd) when the running minimum distance is cached, roughly the cost of one extra Lloyd iteration. For k=10 clusters that is 10 additional distance computations per sample, almost always worth it.
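A tiny worked example shows how D² sampling weights the candidates. The three points and the choice of first centroid are made up for illustration.

```python
import numpy as np

# Three points on a line; assume the first centroid happened to land on X[0].
X = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0]])
first = X[0]

d2 = np.sum((X - first) ** 2, axis=1)  # squared distances: 0, 1, 100
probs = d2 / d2.sum()                  # sampling weights: 0, 1/101, 100/101
print(probs)
```

The already-chosen point has probability 0, the nearby point about 1%, and the distant point about 99%, so the second centroid almost certainly lands in the other region of the data.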

Figure: K-Means++ vs Random Init, Why Spread Matters (diagram not rendered)

Choosing K — Three Methods + When to Use Each

01

Elbow Method (Inertia vs K)

Run K-Means for k=1..15. Plot WCSS against k. The curve bends at the 'elbow' — where adding more clusters gives diminishing improvement. Compute numerically: find k where the second derivative of inertia is maximized. Limitation: the elbow is often ambiguous on real data (smooth curve, no clear bend). Best for: exploratory analysis, verifying that your intuitive k makes sense.

02

Silhouette Score

For each point xᵢ: a(i) = mean distance to other points in same cluster (cohesion). b(i) = mean distance to points in the nearest other cluster (separation). Silhouette(i) = (b(i) - a(i)) / max(a(i), b(i)). Ranges from -1 (likely in the wrong cluster) to +1 (well separated). The k with the highest average silhouette score is the usual pick. More objective than elbow. Limitation: O(n²) computation, so compute it on a random subsample for large datasets.

03

Gap Statistic (Tibshirani et al.)

Gap(k) = E[log(WCSS_random(k))] - log(WCSS_data(k)). Compares the clustering of your data with K-Means run on reference datasets drawn uniformly over the data's range. Choose the smallest k such that Gap(k) ≥ Gap(k+1) - s(k+1), where s(k+1) is the standard deviation of the reference log(WCSS) values scaled by √(1 + 1/B). Most statistically principled, but it requires generating multiple reference datasets (B=20 typical), making it roughly B× more expensive. Best for: academic papers, when you need a defensible statistical justification for k.

04

Business / Domain Knowledge (Often Best)

If clustering customers: you want 5–8 segments that product teams can act on. If clustering documents: topic models suggest 20–100 topics. If clustering for dimensionality reduction: pick k based on explained variance. Statistical methods validate k; they rarely discover it from scratch in meaningful applications.
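The gap statistic is the one method above without a code listing, so here is a minimal NumPy sketch of it. Several simplifications are assumptions on my part: a deterministic farthest-first seeding replaces K-Means++ to keep the example reproducible, reference data is sampled uniformly over the bounding box, and B is kept small. The names `farthest_first`, `kmeans_wcss`, and `gap_statistic` are hypothetical helpers.

```python
import numpy as np

def farthest_first(X, k):
    """Deterministic seeding: start at X[0], then repeatedly add the farthest point."""
    idx = [0]
    d2 = ((X - X[0]) ** 2).sum(axis=1)
    for _ in range(k - 1):
        idx.append(int(d2.argmax()))
        d2 = np.minimum(d2, ((X - X[idx[-1]]) ** 2).sum(axis=1))
    return X[idx].astype(float)

def kmeans_wcss(X, k, n_iters=50):
    """WCSS after a fixed number of Lloyd iterations from farthest-first seeds."""
    c = farthest_first(X, k)
    for _ in range(n_iters):
        labels = ((X[:, None, :] - c[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)
        for j in range(k):
            if np.any(labels == j):
                c[j] = X[labels == j].mean(axis=0)
    labels = ((X[:, None, :] - c[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)
    return float(((X - c[labels]) ** 2).sum())

def gap_statistic(X, k_max=5, B=10, seed=0):
    """Return (chosen_k, gaps, s) using the smallest k with Gap(k) >= Gap(k+1) - s(k+1)."""
    rng = np.random.default_rng(seed)
    lo, hi = X.min(axis=0), X.max(axis=0)
    gaps, s = [], []
    for k in range(1, k_max + 1):
        log_w = np.log(kmeans_wcss(X, k))
        # Reference datasets: uniform over the bounding box of X.
        ref = np.array([
            np.log(kmeans_wcss(rng.uniform(lo, hi, size=X.shape), k))
            for _ in range(B)
        ])
        gaps.append(float(ref.mean() - log_w))
        s.append(float(ref.std() * np.sqrt(1.0 + 1.0 / B)))
    for k in range(1, k_max):
        # gaps[k - 1] is Gap(k); gaps[k] and s[k] belong to k + 1.
        if gaps[k - 1] >= gaps[k] - s[k]:
            return k, gaps, s
    return k_max, gaps, s

# Synthetic check: three well-separated blobs arranged in a triangle.
rng = np.random.default_rng(1)
centers = np.array([[0.0, 0.0], [8.0, 0.0], [4.0, 7.0]])
X = np.vstack([rng.normal(loc=c, scale=0.5, size=(100, 2)) for c in centers])

chosen_k, gaps, s = gap_statistic(X)
print(chosen_k)
```

On real data the blobs are rarely this clean, so the chosen k is best read alongside the elbow and silhouette results rather than on its own.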

Automatic K Selection

choose_k.py (Python)
import numpy as np
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.metrics import silhouette_score

def find_optimal_k(X, k_max=15, method="both", sample_size=5000):
    """
    Evaluate k=2..k_max using elbow + silhouette.
    Uses MiniBatchKMeans for speed on large datasets.
    """
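    X = np.asarray(X)  # fancy indexing below needs an array, not a list
    n = len(X)
    Cls = MiniBatchKMeans if n > 50_000 else KMeans

    inertias, sil_scores = [], []
    # Seeded subsample so silhouette scores are reproducible across runs
    rng = np.random.default_rng(42)
    sample_idx = rng.choice(n, min(sample_size, n), replace=False)
    X_sample = X[sample_idx]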

    for k in range(2, k_max + 1):
        km = Cls(n_clusters=k, init="k-means++", n_init=5, random_state=42)
        labels = km.fit_predict(X)
        inertias.append(km.inertia_)

        # Silhouette on subsample (O(n²) → O(sample²))
        sil_labels = km.predict(X_sample)
        if len(set(sil_labels)) > 1:  # need at least 2 clusters
            sil_scores.append(silhouette_score(X_sample, sil_labels))
        else:
            sil_scores.append(-1)

    # Elbow: find k maximizing second derivative of inertia
    deltas = np.diff(inertias)
    second_deriv = np.diff(deltas)
    elbow_k = 2 + np.argmax(second_deriv) + 1 if len(second_deriv) > 0 else 2

    # Silhouette: pick k with highest average score
    sil_k = 2 + np.argmax(sil_scores)
    best_sil = max(sil_scores)

    return {
        "elbow_k": elbow_k,
        "silhouette_k": sil_k,
        "best_silhouette": round(best_sil, 3),
        "inertias": inertias,
        "silhouette_scores": sil_scores,
        "recommendation": sil_k if best_sil > 0.5 else elbow_k
    }

K-Means Variants — When to Use Each

Variant | Key Change | Pros | Cons | When to Use
K-Means | Lloyd's algorithm (hard-assignment EM) | Simple; each step solved exactly | O(nkd) per iteration; sensitive to initialization | n < 100K
K-Means++ | D² initialization | Better local optima; O(log k) guarantee in expectation | Slight init overhead | Default over random init always
Mini-Batch K-Means | Random batch per iteration, online update | Memory proportional to the batch size; 3–10× faster | Slightly worse WCSS (~1%) | n > 100K or streaming data
K-Medoids (PAM) | Centers must be actual data points | Works with any dissimilarity measure; less sensitive to outliers than a mean | O(kn²) per iteration, very slow | Small data with outliers or non-Euclidean distances
K-Means‖ (parallel++) | Multiple candidates per round, fewer passes | Distributed-friendly init | More complex | Spark / MapReduce K-Means
Bisecting K-Means | Start with one cluster, repeatedly split one (for example the largest) | More stable than random init; fewer restarts needed | Does not optimize the global J directly | Hierarchical exploration needed
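The Mini-Batch row's "online update" can be sketched as follows: each centroid keeps a count of the points it has absorbed and moves toward every new point with a step size of 1/count, so it tracks a running mean. This is a simplified illustration; the function name, batch size, and step count are assumptions, and a library implementation adds refinements such as reassigning starved centroids.

```python
import numpy as np

def minibatch_kmeans(X, k, batch_size=64, n_steps=200, seed=0):
    """Mini-batch K-Means sketch with a per-centroid 1/count learning rate."""
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    counts = np.zeros(k)
    for _ in range(n_steps):
        batch = X[rng.choice(len(X), size=batch_size, replace=False)]
        # Cache nearest centroids for the whole batch before updating any of them.
        d2 = ((batch[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        nearest = d2.argmin(axis=1)
        for x, j in zip(batch, nearest):
            counts[j] += 1
            eta = 1.0 / counts[j]  # shrinking step size gives a running mean
            centroids[j] = (1.0 - eta) * centroids[j] + eta * x
    return centroids

# Two well-separated synthetic blobs centred at x = 0 and x = 10.
rng = np.random.default_rng(0)
X = np.vstack([
    rng.normal(loc=(0.0, 0.0), scale=1.0, size=(500, 2)),
    rng.normal(loc=(10.0, 0.0), scale=1.0, size=(500, 2)),
])
centroids = minibatch_kmeans(X, k=2)
print(np.sort(centroids[:, 0]))
```

Only one batch is held in the distance computation at a time, which is where the memory saving in the table comes from.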
⚠ WARNING

When K-Means Fails — Know the 4 Assumptions

K-Means assumes:
(1) SPHERICAL clusters: fails on elongated, banana-shaped, or ring-shaped clusters. DBSCAN or Spectral Clustering handle these.
(2) EQUAL VARIANCE: fails when clusters have very different spreads. GMM with unconstrained covariances handles this.
(3) EQUAL SIZE: a large cluster may be split while a small cluster is absorbed by a neighbor. K-Means itself has no minimum-cluster-size setting; a GMM with learned mixing weights or a density-based method copes better.
(4) EUCLIDEAN DISTANCE is meaningful: fails for high-dimensional sparse data (cosine similarity works better), categorical features (use K-Modes or Gower distance), and mixed types (normalize, then apply an appropriate distance).
Practical check: after clustering, visualize with t-SNE/UMAP. If clusters overlap heavily in the visualization, treat it as a warning sign that K-Means may be a poor fit, and confirm with a quantitative score such as silhouette, since 2-D embeddings can distort distances.
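The spherical-cluster assumption is easy to see on two concentric rings. The sketch below runs plain Lloyd iterations with k=2, starting from one centroid on each ring, and measures how well the result agrees with the true rings. The data, radii, and starting centroids are assumptions chosen to keep the example deterministic.

```python
import numpy as np

# Two concentric rings (radii 1 and 5), evenly spaced and noise-free.
theta = np.linspace(0.0, 2.0 * np.pi, 200, endpoint=False)
unit = np.column_stack([np.cos(theta), np.sin(theta)])
X = np.vstack([1.0 * unit, 5.0 * unit])
ring = np.repeat([0, 1], 200)  # ground truth: 0 = inner ring, 1 = outer ring

# A favourable start for K-Means: one centroid placed on each ring.
centroids = np.array([[1.0, 0.0], [5.0, 0.0]])

for _ in range(100):
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = d2.argmin(axis=1)
    new = np.array([X[labels == j].mean(axis=0) for j in range(2)])
    if np.allclose(new, centroids):
        break
    centroids = new

# Agreement with the true rings under the better of the two label mappings.
agreement = max(np.mean(labels == ring), np.mean(labels != ring))
print(f"agreement with true rings: {agreement:.2f}")
```

Because K-Means separates clusters with a straight boundary between centroids, a large part of the outer ring ends up in the same cluster as the inner ring, and the agreement stays well below 1.0 even from this favourable start.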

K-Means vs DBSCAN vs GMM — Decision Guide

Algorithm | Cluster Shape | Outliers | Soft Assignments | Scalability | Best For
K-Means | Spherical | Sensitive (assigns to nearest) | No | O(nkd) per iteration, excellent | Large structured data, known k
DBSCAN | Arbitrary density-connected | Labeled as noise class | No | O(n log n) with a spatial index in low dimensions | Geographic data, arbitrary shapes, unknown k
GMM | Elliptical Gaussian | Moderate (tails exist) | Yes (probabilities) | O(nkd²) per iteration with full covariances, slower | When you need P(cluster|x), confidence scores
Agglomerative | Arbitrary (dendrogram) | Moderate | No | O(n² log n), slow | Small data, exploring hierarchical structure
Spectral | Non-convex (graph-based) | Sensitive | No | O(n³), very slow | Ring/manifold clusters, small to medium n

Distributed K-Means (MapReduce Pattern)

distributed_kmeans.py (Python)
import numpy as np

def mapreduce_kmeans_iteration(data_shards, centroids, k):
    """
    One MapReduce iteration of K-Means.
    
    Map phase: each shard independently assigns points and
               computes per-cluster partial sums.
    Reduce phase: aggregate partial sums → new centroids.
    
    Key guarantee: produces the same centroids as sequential K-Means
    (up to floating-point rounding), because the per-cluster sums and
    counts can be combined in any order.
    """

    # MAP: each shard independently
    def map_shard(X_shard, centroids):
        labels = np.argmin(
            np.sum((X_shard[:, None] - centroids[None]) ** 2, axis=2),
            axis=1
        )
        # Partial sum: (sum_of_points, count) per cluster
        partial_sums = np.zeros((k, centroids.shape[1]))
        counts = np.zeros(k, dtype=int)
        for j in range(k):
            mask = labels == j
            if mask.any():
                partial_sums[j] = X_shard[mask].sum(axis=0)
                counts[j] = mask.sum()
        return partial_sums, counts

    # Simulate parallel map
    all_partial_sums = np.zeros((k, centroids.shape[1]))
    all_counts = np.zeros(k, dtype=int)

    for shard in data_shards:
        partial_sums, counts = map_shard(shard, centroids)
        all_partial_sums += partial_sums  # REDUCE: just sum
        all_counts += counts

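    # Final centroid = sum / count (same as mean of all points).
    # Clamp the divisor to 1 so empty clusters do not trigger a divide-by-zero
    # warning; np.where then keeps the old centroid for those clusters.
    safe_counts = np.maximum(all_counts, 1)[:, None]
    new_centroids = np.where(
        all_counts[:, None] > 0,
        all_partial_sums / safe_counts,
        centroids  # keep old if empty
    )
    return new_centroids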

# Notes on distributed implementation:
# - K-Means‖ (parallel++) is the distributed init: over a few rounds
#   (O(log n) in the analysis), each round samples O(k) candidates in parallel;
#   the candidate set is then reduced to k via weighted K-Means.
# - Mini-Batch K-Means is an online approximation that also suits sharded or streaming data.
# - The map phase is embarrassingly parallel: once the current centroids are
#   broadcast, shards need no communication with each other.
# - Only the reduce (centroid aggregation) requires cross-node communication.
# - This makes K-Means one of the most MapReduce-friendly ML algorithms.
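The equivalence between sharded and sequential updates can be verified directly. This standalone sketch does not call the listing above; the shard count, data, and helper name `nearest` are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
centroids = X[:4].copy()  # data points as centroids, so no cluster is empty
k = len(centroids)

def nearest(points, centroids):
    """Index of the closest centroid for every point."""
    return ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)

# Sequential update: mean of each cluster over the full dataset.
labels = nearest(X, centroids)
sequential = np.array([X[labels == j].mean(axis=0) for j in range(k)])

# Sharded update: accumulate per-cluster sums and counts shard by shard.
sums = np.zeros_like(centroids)
counts = np.zeros(k)
for shard in np.array_split(X, 7):
    shard_labels = nearest(shard, centroids)
    np.add.at(sums, shard_labels, shard)              # unbuffered scatter-add
    counts += np.bincount(shard_labels, minlength=k)
sharded = sums / counts[:, None]

print(np.allclose(sequential, sharded))
```

The comparison uses `np.allclose` rather than exact equality because summing in a different order can change the last few bits of a floating-point result.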
EXAMPLE

Interview Scenario: Cluster 10M User Profiles

Setup: 10M users, 500 features (behavior + demographics), goal = 20 segments for personalization.

Step 1 — Dimensionality reduction: PCA to 50 components (retain 95% variance) or autoencoder for non-linear structure. Reduces computation by 10×.
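A minimal sketch of the PCA step, choosing the number of components from a variance target rather than fixing it at 50. It uses an SVD on synthetic low-rank data; the function name `pca_reduce` and the data are assumptions, and at 10M rows a randomized or incremental PCA would be the practical choice.

```python
import numpy as np

def pca_reduce(X, var_target=0.95):
    """Project X onto the fewest principal components that reach var_target."""
    Xc = X - X.mean(axis=0)
    _, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    explained = S ** 2 / np.sum(S ** 2)  # variance ratio per component
    # First index where the cumulative ratio reaches the target, as a count.
    n_comp = int(np.searchsorted(np.cumsum(explained), var_target)) + 1
    n_comp = min(n_comp, len(S))
    return Xc @ Vt[:n_comp].T, n_comp, explained

# Synthetic data: 40 observed features driven by 5 latent factors plus small noise.
rng = np.random.default_rng(0)
latent = rng.normal(size=(500, 5))
mixing = rng.normal(size=(5, 40))
X = latent @ mixing + 0.01 * rng.normal(size=(500, 40))

reduced, n_comp, explained = pca_reduce(X)
print(n_comp, reduced.shape)
```

Whether 50 components actually retain 95% of the variance depends on the data, so it is worth checking the cumulative ratio before committing to a dimension.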

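Step 2 — Algorithm choice: Mini-Batch K-Means, batch_size=10000. Full K-Means needs O(10M × 20 × 50) = 10B multiply-adds per iteration, which is acceptable on GPU, but the full reduced dataset occupies 10M × 50 × 8 bytes = 4 GB in float64. A 10,000-row mini-batch needs only 10,000 × 50 × 8 bytes = 4 MB per step.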
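The sizing arithmetic in Step 2 can be reproduced in a few lines. The variable names are illustrative, and float64 storage is assumed.

```python
n, k, d = 10_000_000, 20, 50
batch_size = 10_000
bytes_per_value = 8  # float64

# Distance computation per full iteration: one multiply-add per (point, centroid, feature).
mult_adds_per_iter = n * k * d

# Memory for the full reduced dataset versus a single mini-batch.
full_data_gb = n * d * bytes_per_value / 1e9
batch_mb = batch_size * d * bytes_per_value / 1e6

print(mult_adds_per_iter, full_data_gb, batch_mb)  # 10000000000 4.0 4.0
```

Storing the reduced features as float32 would halve both memory figures.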

Step 3 — Initialization: K-Means‖ (parallel++) on a 100K sample. Then warm-start Mini-Batch with these centroids.

Step 4 — Choosing k: silhouette score on 50K holdout sample for k=10, 15, 20, 25, 30. Also evaluate downstream metric (CTR, conversion) for top 3 candidates.

Step 5 — Stability: re-cluster monthly. Track cluster drift using centroid shift + overlap of nearest-neighbor membership. Map old clusters to new via Hungarian algorithm on centroid distance matrix.
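The cluster-matching idea in Step 5 can be sketched on a tiny example. Here an exhaustive search over permutations stands in for the Hungarian algorithm, which is only feasible for small k; for k=20, a solver such as SciPy's `linear_sum_assignment` would be the practical choice. The function name `match_clusters` and the centroid values are assumptions.

```python
import numpy as np
from itertools import permutations

def match_clusters(old_centroids, new_centroids):
    """Return (perm, cost) where perm[i] is the new cluster matched to old cluster i."""
    k = len(old_centroids)
    # cost[i, j] = Euclidean distance between old centroid i and new centroid j.
    cost = np.linalg.norm(old_centroids[:, None, :] - new_centroids[None, :, :], axis=2)
    # Brute force over all k! assignments (a stand-in for the Hungarian algorithm).
    best = min(permutations(range(k)),
               key=lambda p: sum(cost[i, p[i]] for i in range(k)))
    return list(best), cost

old = np.array([[0.0, 0.0], [5.0, 5.0], [10.0, 0.0]])
# Same clusters after re-fitting: slightly shifted and with shuffled label IDs.
new = np.array([[10.2, 0.1], [0.1, -0.1], [5.1, 4.8]])

perm, cost = match_clusters(old, new)
drift = [float(cost[i, perm[i]]) for i in range(len(old))]
print(perm)  # [1, 2, 0]
```

The matched distances in `drift` give a per-segment centroid shift that can be tracked month over month.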
