9.5 K Closest Points to Origin

9.5.1 Problem Metadata

9.5.2 Description

Given an array of points where points[i] = [xi, yi] represents a point on the X-Y plane and an integer k, return the k closest points to the origin (0, 0).

The distance between two points on the X-Y plane is the Euclidean distance (i.e., \(\sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}\)).

You may return the answer in any order. The answer is guaranteed to be unique (except for the order that it is in).

9.5.3 Examples

Example 1:

Input: points = [[1,3],[-2,2]], k = 1
Output: [[-2,2]]
Explanation:
The distance between (1, 3) and the origin is sqrt(10).
The distance between (-2, 2) and the origin is sqrt(8).
Since sqrt(8) < sqrt(10), (-2, 2) is closer to the origin.
We only want the closest k = 1 points from the origin, so the answer is just [[-2,2]].

Example 2:

Input: points = [[3,3],[5,-1],[-2,4]], k = 2
Output: [[3,3],[-2,4]]
Explanation: The answer [[-2,4],[3,3]] would also be accepted.

9.5.4 Constraints

  • \(1 \le k \le \text{points.length} \le 10^4\)
  • \(-10^4 \le x_i, y_i \le 10^4\)

9.5.5 Solution - Max Heap

9.5.5.1 Walkthrough

This solution uses a max heap to efficiently find the k closest points without fully sorting the array.

Core Strategy:

Instead of sorting all n points, maintain a max heap of size k that always contains the k closest points seen so far. The heap is ordered by distance (max at top), allowing us to quickly discard farther points.

Key Insight: Why Max Heap, Not Min Heap?

  • Max heap keeps the farthest of the k closest points at the top
  • When we encounter a new point:
    • If heap size \(<\) k: add it (we need more points)
    • If heap size \(=\) k and new point is closer than heap top: remove top (farthest), add new point
    • Otherwise: skip it (all k current points are closer)
  • This maintains exactly k closest points with O(log k) operations per point
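The bookkeeping above can be sketched on plain integers (a standalone toy following the decision rules in these bullets; the full solution class below wraps points in a helper class instead):

```java
import java.util.Collections;
import java.util.PriorityQueue;

class SizeKHeapSketch {
    // Keep the k smallest values from a stream of squared distances.
    static PriorityQueue<Integer> kSmallest(int[] distSq, int k) {
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>(Collections.reverseOrder());
        for (int d : distSq) {
            if (maxHeap.size() < k) {
                maxHeap.add(d);              // still collecting the first k values
            } else if (d < maxHeap.peek()) {
                maxHeap.poll();              // evict the current farthest
                maxHeap.add(d);              // keep the closer value instead
            }                                // else: skip, all k kept values are closer
        }
        return maxHeap;
    }

    public static void main(String[] args) {
        // Two smallest of {10, 8, 50, 8, 2} are 2 and 8
        System.out.println(kSmallest(new int[]{10, 8, 50, 8, 2}, 2));
    }
}
```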

Distance Optimization:

Since we only need to compare distances (not compute actual values), we use squared distance to avoid expensive sqrt() operations:

  • Distance: \(\sqrt{x^2 + y^2}\)
  • Squared distance: \(x^2 + y^2\) (sufficient for comparisons since sqrt is monotonic)
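A quick sanity check of the monotonicity claim (a toy snippet, not part of the solution):

```java
class SquaredDistanceCheck {
    static int distSq(int[] p) { return p[0] * p[0] + p[1] * p[1]; }

    public static void main(String[] args) {
        int[] a = {1, 3};    // true distance sqrt(10)
        int[] b = {-2, 2};   // true distance sqrt(8)
        // sqrt is monotonic, so comparing squared distances gives the same verdict
        boolean bySquared = distSq(b) < distSq(a);
        boolean byRoot    = Math.sqrt(distSq(b)) < Math.sqrt(distSq(a));
        System.out.println(bySquared + " " + byRoot);  // true true
    }
}
```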

Step-by-step execution for points = [[1,3],[-2,2],[2,2]], k = 2:

Distances (squared):
- (1,3):   1² + 3² = 10
- (-2,2):  4 + 4 = 8
- (2,2):   4 + 4 = 8

Processing:
1. Add (1,3) dist=10:   Heap = [(1,3):10]                    size=1
2. Add (-2,2) dist=8:   Heap = [(1,3):10, (-2,2):8]          size=2 (reached k)
3. Process (2,2) dist=8:
   - Heap top = (1,3):10 > 8
   - Remove (1,3):10
   - Add (2,2):8
   - Heap = [(-2,2):8, (2,2):8]                              size=2

Result: [[-2,2], [2,2]] or [[2,2], [-2,2]] (order doesn't matter)
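The trace above can be reproduced with a small driver using the same add-then-trim heap (a standalone sketch on raw int[] points; the full solution class follows below):

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.PriorityQueue;

class HeapTraceDemo {
    // Max heap on squared distance, trimmed to size k after each insert
    static PriorityQueue<int[]> kClosestHeap(int[][] points, int k) {
        PriorityQueue<int[]> heap = new PriorityQueue<>(
            Comparator.comparingInt((int[] p) -> p[0] * p[0] + p[1] * p[1]).reversed());
        for (int[] p : points) {
            heap.add(p);
            if (heap.size() > k) heap.poll();  // step 3 of the trace evicts (1,3):10 here
        }
        return heap;
    }

    public static void main(String[] args) {
        for (int[] p : kClosestHeap(new int[][]{{1, 3}, {-2, 2}, {2, 2}}, 2)) {
            System.out.println(Arrays.toString(p));  // (-2,2) and (2,2), in heap order
        }
    }
}
```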

Why This Works:

  • After processing all points, the heap contains exactly the k smallest distances
  • Max heap property ensures we can always access and remove the farthest point in O(log k)
  • Total operations: n insertions × O(log k) = O(n log k)

9.5.5.2 Analysis

  • Time Complexity: O(n log k)
    • Process each of n points: O(n)
    • Each point may trigger heap insert/remove: O(log k)
    • Building result array: O(k)
    • Total: O(n log k) - much better than O(n log n) when k << n
    • Example: If n=10,000 and k=10, this is ~33,000 heap operations (n log₂k ≈ 10,000 × 3.3) vs ~133,000 comparisons for a full sort (n log₂n ≈ 10,000 × 13.3)
  • Space Complexity: O(k)
    • Max heap stores at most k points
    • Result array: O(k) (required output)
    • This is optimal - we only need to remember k points

Comparison with Full Sort:

  • When k is small (k << n): the max heap wins significantly
    • k=1, n=10,000: ~10,000 heap operations vs ~130,000 sort comparisons - roughly 13× faster
  • When k \(\approx\) n: both approaches are similar, O(n log n)
  • Space: the heap uses O(k) vs O(n) for the sort-based approach

9.5.5.3 Implementation Steps

Setup:

  1. Create helper class XYPoint to store:
    • Original coordinates (x, y)
    • Precomputed squared distance: distSq = x² + y²
  2. Create max heap ordered by distSq (descending)

Main Algorithm:

  3. For each point in the input array:
    1. Wrap the point as an XYPoint object
    2. Add it to the max heap
    3. If the heap size exceeds k:
      • Remove the heap top (the farthest point)
      • This maintains exactly the k closest points

Build Result:

  4. Extract all k points from the heap into the result array:
    • Poll the heap k times
    • Convert each XYPoint back to [x, y] format
  5. Return the result array

Note: Order doesn’t matter per problem constraints, so heap extraction order is acceptable.

9.5.5.4 Code - Java

class Solution {
    private static class XYPoint {
        public final int x;
        public final int y;
        public final int distSq;

        public XYPoint(int xi, int yi) {
            this.x = xi;
            this.y = yi;
            // Use squared distance to avoid expensive sqrt()
            this.distSq = x * x + y * y;
        }
    }

    public int[][] kClosest(int[][] points, int k) {
        // Max heap ordered by distance (largest distance at top)
        PriorityQueue<XYPoint> maxHeap = new PriorityQueue<>(
            Comparator.comparingInt((XYPoint point) -> point.distSq).reversed()
        );

        // Process each point
        for (int[] point : points) {
            maxHeap.add(new XYPoint(point[0], point[1]));

            // Keep only k closest points
            if (maxHeap.size() > k) {
                maxHeap.poll(); // Remove farthest point
            }
        }

        // Build result array from heap
        int[][] result = new int[k][2];
        for (int i = 0; i < k; i++) {
            XYPoint myPoint = maxHeap.poll();
            result[i][0] = myPoint.x;
            result[i][1] = myPoint.y;
        }

        return result;
    }
}

9.5.6 Solution - Sort + Stream API

9.5.6.1 Walkthrough

This solution uses Java’s Stream API to sort all points by distance and select the first k.

Approach:

  1. Convert all points to XYPoint objects with precomputed squared distances
  2. Stream the list and sort by distSq (ascending)
  3. Limit to the first k elements
  4. Collect and convert back to a 2D array

Trade-offs:

  • Simpler, more readable code
  • Less efficient: sorts all n points even when k is small
  • Good for interviews when code clarity is prioritized over optimal complexity
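If the helper class feels heavyweight, the same approach can also be collapsed into a single pipeline directly on the int[] rows (an alternative sketch, not the XYPoint version shown below):

```java
import java.util.Arrays;
import java.util.Comparator;

class StreamKClosestSketch {
    static int[][] kClosest(int[][] points, int k) {
        return Arrays.stream(points)
            .sorted(Comparator.comparingInt((int[] p) -> p[0] * p[0] + p[1] * p[1]))
            .limit(k)                       // keep only the k closest
            .toArray(int[][]::new);
    }

    public static void main(String[] args) {
        int[][] res = kClosest(new int[][]{{3, 3}, {5, -1}, {-2, 4}}, 2);
        System.out.println(Arrays.deepToString(res));  // [[3, 3], [-2, 4]]
    }
}
```

Note that sorting mutates nothing here: the stream sorts references to the rows, leaving the input array itself untouched.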

9.5.6.2 Analysis

  • Time Complexity: O(n log n)
    • Creating list: O(n)
    • Sorting all points: O(n log n)
    • Taking first k: O(k)
    • Total dominated by sorting: O(n log n)
  • Space Complexity: O(n)
    • List stores all n points
    • Stream intermediate operations: O(n)
    • Result array: O(k)

When to Use This:

  • Code readability is prioritized
  • k is close to n (no benefit from the heap optimization)
  • Quick prototyping or interviews with relaxed time constraints

9.5.6.3 Code - Java

class Solution {
    private static class XYPoint {
        public final int x;
        public final int y;
        public final int distSq;

        public XYPoint(int xi, int yi) {
            this.x = xi;
            this.y = yi;
            this.distSq = x * x + y * y;
        }
    }

    public int[][] kClosest(int[][] points, int k) {
        List<XYPoint> myPoints = new ArrayList<>();

        // Convert to XYPoint objects
        for (int[] point : points) {
            myPoints.add(new XYPoint(point[0], point[1]));
        }

        // Sort by distance and take first k
        List<XYPoint> topK = myPoints.stream()
            .sorted(Comparator.comparingInt((XYPoint point) -> point.distSq))
            .limit(k)
            .collect(Collectors.toList());

        // Convert back to 2D array
        int[][] result = new int[k][2];
        for (int i = 0; i < k; i++) {
            XYPoint myPoint = topK.get(i);
            result[i][0] = myPoint.x;
            result[i][1] = myPoint.y;
        }

        return result;
    }
}

9.5.7 Solution - QuickSelect

9.5.7.1 Walkthrough

This solution uses QuickSelect, a partition-based algorithm that achieves O(n) average time with O(1) auxiliary space (ignoring the recursion stack) - the optimal approach for this problem.

Core Strategy:

QuickSelect is similar to QuickSort but only recurses on one side. Instead of fully sorting, we partition the array around a pivot until exactly k elements are on the left (closest) side.

Algorithm Steps:

  1. Pick a pivot (random or deterministic)
  2. Partition: Rearrange array so:
    • All points closer than pivot are on the left
    • All points farther than pivot are on the right
    • Pivot is at its final sorted position
  3. Check the partition result:
    • If pivot index \(=\) k-1: Done! The first k elements are our answer
    • If pivot index \(<\) k-1: Recurse on the right side (need more points)
    • If pivot index \(>\) k-1: Recurse on the left side (have too many)

Key Insight: Why This Works

We don’t need the first k elements to be sorted among themselves - we just need them to be the k smallest. QuickSelect guarantees that once a pivot lands at index k-1:

  • Elements [0...k-1] are all no farther than elements [k...n-1]
  • The order within [0...k-1] doesn’t matter (the problem allows any order)
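This guarantee can be checked on the squared distances alone, using the same Lomuto-style partition as the implementation below (a standalone sketch on plain ints):

```java
import java.util.Arrays;

class PartitionPropertyDemo {
    // Lomuto partition around the last element; returns the pivot's final index
    static int partition(int[] a, int left, int right) {
        int pivot = a[right];
        int i = left;                       // write position for elements <= pivot
        for (int j = left; j < right; j++) {
            if (a[j] <= pivot) {
                int t = a[i]; a[i] = a[j]; a[j] = t;
                i++;
            }
        }
        int t = a[i]; a[i] = a[right]; a[right] = t;
        return i;
    }

    public static void main(String[] args) {
        int[] dist = {10, 8, 50, 8};        // squared distances; pivot value is 8
        int p = partition(dist, 0, dist.length - 1);
        // Everything left of index p is <= dist[p]; everything right of it is > dist[p]
        System.out.println(p + " " + Arrays.toString(dist));  // 1 [8, 8, 50, 10]
    }
}
```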

Visual Example for points = [[1,3],[-2,2],[2,2],[5,5]], k = 2:

Squared distances: [(1,3):10, (-2,2):8, (2,2):8, (5,5):50]

Iteration 1: Pick pivot index 0 ((1,3), squared distance 10)
  Swap pivot to the end: [(5,5):50, (-2,2):8, (2,2):8, (1,3):10]
  Partition:             [(-2,2):8, (2,2):8, (1,3):10, (5,5):50]
  Pivot lands at index 2 (0-indexed)
  k-1 = 1, pivot index = 2: pivot index > k-1, recurse left

Iteration 2: Partition the left half [0, 1]
  Pick pivot index 1 ((2,2), squared distance 8)
  Already in place:      [(-2,2):8, (2,2):8, (1,3):10, (5,5):50]
  Pivot lands at index 1
  k-1 = 1, pivot index = 1: Done!

Result: First 2 elements [(-2,2), (2,2)] (order doesn't matter)

Optimization: Randomized Pivot

Using a random pivot avoids the worst-case O(n²) behavior on sorted or reverse-sorted inputs:

  • Without randomization (e.g., always picking the last element): O(n²) on sorted input
  • With randomization: O(n) average case with high probability

9.5.7.2 Analysis

  • Time Complexity:
    • Average case: O(n) - each partition reduces problem size by ~half
      • First partition: n operations
      • Second partition: n/2 operations
      • Third partition: n/4 operations
      • Sum: n + n/2 + n/4 + … ≤ 2n = O(n)
    • Worst case: O(n²) - bad pivot choices every time (rare with randomization)
    • Best case: O(n) - optimal pivot every time
  • Space Complexity: O(1) auxiliary
    • In-place partitioning - no extra arrays, heaps, or lists
    • Only a few variables for indices and swapping
    • The recursion stack adds O(log n) on average; since each call recurses on only one side, it can be rewritten as a loop for strict O(1)

Comparison with Other Solutions:

Solution      Time (Avg)   Time (Worst)  Space  When to Use
QuickSelect   O(n)         O(n²)         O(1)   Best overall - optimal time & space
Max Heap      O(n log k)   O(n log k)    O(k)   Small k; works on streaming input
Full Sort     O(n log n)   O(n log n)    O(n)   Code simplicity, k ≈ n

When to Use QuickSelect:

  • Optimal performance is needed (technical interviews, production code)
  • Large datasets where O(1) space matters
  • k varies widely (works well for all k values)

9.5.7.3 Implementation Steps

Helper Function: partition(points, left, right)

  1. Choose a random pivot index in [left, right]
  2. Swap the pivot to the end position
  3. Use the two-pointer technique:
    • i = write position for elements \(\le\) pivot
    • j = read position scanning all elements
  4. For each element from left to right-1:
    • If its distance \(\le\) the pivot distance: swap it to position i, increment i
  5. Swap the pivot back to position i (its final position)
  6. Return i (the pivot’s final index)

Main Function: quickSelect(points, left, right, k)

  7. If left \(\ge\) right: return (base case)
  8. Partition the array and get the pivot index
  9. If pivot index \(=\) k-1: Done! The first k elements are the answer
  10. If pivot index \(<\) k-1: Recurse right: quickSelect(points, pivotIndex+1, right, k)
  11. If pivot index \(>\) k-1: Recurse left: quickSelect(points, left, pivotIndex-1, k)

Build Result:

  12. Extract the first k points from the partially partitioned array

9.5.7.4 Code - Java

class Solution {
    private Random rand = new Random();

    public int[][] kClosest(int[][] points, int k) {
        quickSelect(points, 0, points.length - 1, k);

        // First k elements are now the k closest (not sorted, but that's okay)
        return Arrays.copyOfRange(points, 0, k);
    }

    private void quickSelect(int[][] points, int left, int right, int k) {
        if (left >= right) {
            return;
        }

        // Partition and get pivot's final position
        int pivotIndex = partition(points, left, right);

        // Check if we've found exactly k elements on the left
        if (pivotIndex == k - 1) {
            return; // Perfect! First k elements are the answer
        } else if (pivotIndex < k - 1) {
            // Need more elements, recurse right
            quickSelect(points, pivotIndex + 1, right, k);
        } else {
            // Have too many elements, recurse left
            quickSelect(points, left, pivotIndex - 1, k);
        }
    }

    private int partition(int[][] points, int left, int right) {
        // Randomized pivot to avoid worst-case O(n²)
        int randomPivot = left + rand.nextInt(right - left + 1);
        swap(points, randomPivot, right); // Move pivot to end

        int pivotDist = distance(points[right]);
        int i = left; // Write position for elements <= pivot

        // Partition: move all closer points to the left
        for (int j = left; j < right; j++) {
            if (distance(points[j]) <= pivotDist) {
                swap(points, i, j);
                i++;
            }
        }

        // Move pivot to its final position
        swap(points, i, right);
        return i;
    }

    private int distance(int[] point) {
        return point[0] * point[0] + point[1] * point[1];
    }

    private void swap(int[][] points, int i, int j) {
        int[] temp = points[i];
        points[i] = points[j];
        points[j] = temp;
    }
}

9.5.7.5 Code - Golang

func kClosest(points [][]int, k int) [][]int {
    quickSelect(points, 0, len(points)-1, k)

    // First k elements are now the k closest (not sorted, but that's okay)
    return points[:k]
}

func quickSelect(points [][]int, left, right, k int) {
    if left >= right {
        return
    }

    // Partition and get pivot's final position
    pivotIndex := partition(points, left, right)

    // Check if we've found exactly k elements on the left
    if pivotIndex == k-1 {
        return // Perfect! First k elements are the answer
    } else if pivotIndex < k-1 {
        // Need more elements, recurse right
        quickSelect(points, pivotIndex+1, right, k)
    } else {
        // Have too many elements, recurse left
        quickSelect(points, left, pivotIndex-1, k)
    }
}

func partition(points [][]int, left, right int) int {
    // Randomized pivot to avoid worst-case O(n²)
    randomPivot := left + rand.Intn(right-left+1)
    points[randomPivot], points[right] = points[right], points[randomPivot]

    pivotDist := distance(points[right])
    i := left // Write position for elements <= pivot

    // Partition: move all closer points to the left
    for j := left; j < right; j++ {
        if distance(points[j]) <= pivotDist {
            points[i], points[j] = points[j], points[i]
            i++
        }
    }

    // Move pivot to its final position
    points[i], points[right] = points[right], points[i]
    return i
}

func distance(point []int) int {
    return point[0]*point[0] + point[1]*point[1]
}