9.5 K Closest Points to Origin
9.5.1 Problem Metadata
- Platform: LeetCode
- Problem ID: 973
- Difficulty: Medium
- URL: https://leetcode.com/problems/k-closest-points-to-origin/
- Tags: Grind 75, NeetCode 150
- Techniques: Heap, QuickSelect, Sorting, Array
9.5.2 Description
Given an array of points where points[i] = [xi, yi] represents a point on the X-Y plane and an integer k, return the k closest points to the origin (0, 0).
The distance between two points on the X-Y plane is the Euclidean distance (i.e., \(\sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}\)).
You may return the answer in any order. The answer is guaranteed to be unique (except for the order that it is in).
9.5.3 Examples
Example 1:
Input: points = [[1,3],[-2,2]], k = 1
Output: [[-2,2]]
Explanation:
The distance between (1, 3) and the origin is sqrt(10).
The distance between (-2, 2) and the origin is sqrt(8).
Since sqrt(8) < sqrt(10), (-2, 2) is closer to the origin.
We only want the closest k = 1 points from the origin, so the answer is just [[-2,2]].
Example 2:
Input: points = [[3,3],[5,-1],[-2,4]], k = 2
Output: [[3,3],[-2,4]]
Explanation: The answer [[-2,4],[3,3]] would also be accepted.
9.5.5 Solution - Max Heap
9.5.5.1 Walkthrough
This solution uses a max heap to efficiently find the k closest points without fully sorting the array.
Core Strategy:
Instead of sorting all n points, maintain a max heap of size k that always contains the k closest points seen so far. The heap is ordered by distance (max at top), allowing us to quickly discard farther points.
Key Insight: Why Max Heap, Not Min Heap?
- Max heap keeps the farthest of the k closest points at the top
- When we encounter a new point:
- If heap size \(<\) k: add it (we need more points)
- If heap size \(=\) k and new point is closer than heap top: remove top (farthest), add new point
- Otherwise: skip it (all k current points are closer)
- This maintains exactly k closest points with O(log k) operations per point
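The eviction logic above can be sketched on plain squared distances before wrapping points in a helper class. This is a minimal illustration, not the full solution; the class name `MaxHeapSketch` is invented for the demo:

```java
import java.util.Collections;
import java.util.PriorityQueue;

public class MaxHeapSketch {
    // Keep the k smallest values seen so far using a max heap of size k.
    static PriorityQueue<Integer> kSmallest(int[] values, int k) {
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>(Collections.reverseOrder());
        for (int v : values) {
            maxHeap.add(v);
            if (maxHeap.size() > k) {
                maxHeap.poll(); // evict the current farthest
            }
        }
        return maxHeap;
    }

    public static void main(String[] args) {
        // Squared distances from the walkthrough below: 10, 8, 8
        PriorityQueue<Integer> heap = kSmallest(new int[]{10, 8, 8}, 2);
        System.out.println(heap.contains(10)); // false: 10 was evicted
        System.out.println(heap.size());       // 2
    }
}
```

Each `add`/`poll` pair costs O(log k), which is where the O(n log k) total comes from.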
Distance Optimization:
Since we only need to compare distances (not compute actual values), we use squared distance to avoid expensive sqrt() operations:
- Distance: \(\sqrt{x^2 + y^2}\)
- Squared distance: \(x^2 + y^2\) (sufficient for comparisons since sqrt is monotonic)
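As a quick sanity check, comparing squared distances gives the same verdict as comparing true distances, because squaring preserves order for non-negative values. The `distSq` helper here is illustrative only:

```java
public class SquaredDistanceCheck {
    // Squared Euclidean distance from the origin
    static int distSq(int x, int y) {
        return x * x + y * y;
    }

    public static void main(String[] args) {
        // Points from Example 1: (1,3) and (-2,2)
        int a = distSq(1, 3);   // 10
        int b = distSq(-2, 2);  // 8
        // Same verdict as sqrt(10) > sqrt(8), with no floating point involved
        System.out.println(a > b);                       // true
        System.out.println(Math.sqrt(a) > Math.sqrt(b)); // true
    }
}
```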
Step-by-step execution for points = [[1,3],[-2,2],[2,2]], k = 2:
Distances (squared):
- (1,3): 1² + 3² = 10
- (-2,2): (-2)² + 2² = 8
- (2,2): 2² + 2² = 8
Processing:
1. Add (1,3) dist=10: Heap = [(1,3):10] size=1
2. Add (-2,2) dist=8: Heap = [(1,3):10, (-2,2):8] size=2 (reached k)
3. Process (2,2) dist=8:
- Heap top = (1,3):10 > 8
- Remove (1,3):10
- Add (2,2):8
- Heap = [(-2,2):8, (2,2):8] size=2
Result: [[-2,2], [2,2]] or [[2,2], [-2,2]] (order doesn't matter)
Why This Works:
- After processing all points, the heap contains exactly the k smallest distances
- Max heap property ensures we can always access and remove the farthest point in O(log k)
- Total operations: n insertions × O(log k) = O(n log k)
9.5.5.2 Analysis
- Time Complexity: O(n log k)
- Process each of n points: O(n)
- Each point may trigger heap insert/remove: O(log k)
- Building result array: O(k)
- Total: O(n log k) - much better than O(n log n) when k << n
- Example: If n=10,000 and k=10, this is about 10,000 × log₂ 10 ≈ 33,000 operations vs about 10,000 × log₂ 10,000 ≈ 133,000 for a full sort
- Space Complexity: O(k)
- Max heap stores at most k points
- Result array: O(k) (required output)
- This is optimal - we only need to remember k points
Comparison with Full Sort:
- When k is small (k << n): the max heap wins significantly
  - k=1, n=10,000: O(10,000 × 1) vs O(10,000 × 13), roughly 13× faster
- When k \(\approx\) n: both approaches are similar, O(n log n)
- Space: heap uses O(k) vs O(n) for sorting
9.5.5.3 Implementation Steps
Setup:
1. Create helper class XYPoint to store:
- Original coordinates (x, y)
- Precomputed squared distance: distSq = x² + y²
2. Create max heap ordered by distSq (descending)
Main Algorithm:
3. For each point in input array:
1. Wrap point as XYPoint object
2. Add to max heap
3. If heap size exceeds k:
- Remove heap top (the farthest point)
- This maintains exactly k closest points
Build Result:
4. Extract all k points from heap into result array
- Poll heap k times
- Convert each XYPoint back to [x, y] format
5. Return result array
Note: Order doesn’t matter per problem constraints, so heap extraction order is acceptable.
9.5.5.4 Code - Java
import java.util.Comparator;
import java.util.PriorityQueue;

class Solution {
    private static class XYPoint {
        public final int x;
        public final int y;
        public final int distSq;

        public XYPoint(int xi, int yi) {
            this.x = xi;
            this.y = yi;
            // Use squared distance to avoid expensive sqrt()
            this.distSq = x * x + y * y;
        }
    }

    public int[][] kClosest(int[][] points, int k) {
        // Max heap ordered by distance (largest distance at top)
        PriorityQueue<XYPoint> maxHeap = new PriorityQueue<>(
                Comparator.comparingInt((XYPoint point) -> point.distSq).reversed()
        );
        // Process each point
        for (int[] point : points) {
            maxHeap.add(new XYPoint(point[0], point[1]));
            // Keep only k closest points
            if (maxHeap.size() > k) {
                maxHeap.poll(); // Remove farthest point
            }
        }
        // Build result array from heap
        int[][] result = new int[k][2];
        for (int i = 0; i < k; i++) {
            XYPoint myPoint = maxHeap.poll();
            result[i][0] = myPoint.x;
            result[i][1] = myPoint.y;
        }
        return result;
    }
}
9.5.6 Solution - Sort + Stream API
9.5.6.1 Walkthrough
This solution uses Java’s Stream API to sort all points by distance and select the first k.
Approach:
1. Convert all points to XYPoint objects with precomputed squared distances
2. Stream the list and sort by distSq (ascending)
3. Limit to first k elements
4. Collect and convert back to 2D array
Trade-offs:
- Simpler, more readable code
- Less efficient: sorts all n points even when k is small
- Good for interviews if code clarity is prioritized over optimal complexity
9.5.6.2 Analysis
- Time Complexity: O(n log n)
- Creating list: O(n)
- Sorting all points: O(n log n)
- Taking first k: O(k)
- Total dominated by sorting: O(n log n)
- Space Complexity: O(n)
- List stores all n points
- Stream intermediate operations: O(n)
- Result array: O(k)
When to Use This:
- Code readability is prioritized
- k is close to n (no benefit from heap optimization)
- Quick prototyping or interviews with relaxed time constraints
9.5.6.3 Code - Java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

class Solution {
    private static class XYPoint {
        public final int x;
        public final int y;
        public final int distSq;

        public XYPoint(int xi, int yi) {
            this.x = xi;
            this.y = yi;
            this.distSq = x * x + y * y;
        }
    }

    public int[][] kClosest(int[][] points, int k) {
        List<XYPoint> myPoints = new ArrayList<>();
        // Convert to XYPoint objects
        for (int[] point : points) {
            myPoints.add(new XYPoint(point[0], point[1]));
        }
        // Sort by distance and take first k
        List<XYPoint> topK = myPoints.stream()
                .sorted(Comparator.comparingInt((XYPoint point) -> point.distSq))
                .limit(k)
                .collect(Collectors.toList());
        // Convert back to 2D array
        int[][] result = new int[k][2];
        for (int i = 0; i < k; i++) {
            XYPoint myPoint = topK.get(i);
            result[i][0] = myPoint.x;
            result[i][1] = myPoint.y;
        }
        return result;
    }
}
9.5.7 Solution - QuickSelect
9.5.7.1 Walkthrough
This solution uses QuickSelect, a partition-based algorithm that achieves O(n) average time with O(1) space - the optimal solution for this problem.
Core Strategy:
QuickSelect is similar to QuickSort but only recurses on one side. Instead of fully sorting, we partition the array around a pivot until exactly k elements are on the left (closest) side.
Algorithm Steps:
- Pick a pivot (random or deterministic)
- Partition: Rearrange array so:
- All points closer than pivot are on the left
- All points farther than pivot are on the right
- Pivot is at its final sorted position
- Check partition result:
  - If pivot index \(=\) k-1: Done! First k elements are our answer
  - If pivot index \(<\) k-1: Recurse on the right side (need more points)
  - If pivot index \(>\) k-1: Recurse on the left side (have too many)
Key Insight: Why This Works
We don’t need the first k elements to be sorted among themselves - we just need them to be the k smallest. QuickSelect guarantees that after partitioning at index k-1:
- Elements [0...k-1] are all closer than elements [k...n-1]
- The order within [0...k-1] doesn’t matter (problem allows any order)
Visual Example for points = [[1,3],[-2,2],[2,2],[5,5]], k = 2:
Squared distances: [(1,3):10, (-2,2):8, (2,2):8, (5,5):50]
Iteration 1: Pick pivot (1,3), squared distance 10
  Partition: [(-2,2):8, (2,2):8, (1,3):10, (5,5):50]
  Pivot now at index 2 (0-indexed)
  k=2: pivot index 2 > k-1 = 1, recurse left on [0, 1]
Iteration 2: Pick pivot (2,2), squared distance 8
  Partition leaves [(-2,2):8, (2,2):8] unchanged
  Pivot at index 1
  k=2: pivot index 1 = k-1: Done!
Result: First 2 elements [(-2,2), (2,2)] (order doesn't matter)
Optimization: Randomized Pivot
Using a random pivot avoids the worst-case O(n²) behavior that a fixed pivot choice hits on sorted or reverse-sorted inputs:
- Without randomization: O(n²) on sorted input
- With randomization: expected O(n) on every input, regardless of initial order
9.5.7.2 Analysis
- Time Complexity:
  - Average case: O(n) - each partition reduces the problem size by roughly half
    - First partition: n operations
    - Second partition: n/2 operations
    - Third partition: n/4 operations
    - Sum: n + n/2 + n/4 + … \(<\) 2n = O(n)
  - Worst case: O(n²) - bad pivot choices every time (rare with randomization)
  - Best case: O(n) - optimal pivot every time
- Space Complexity: O(1) auxiliary
  - In-place partitioning: no extra arrays, heaps, or lists
  - Only a few variables for indices and swapping
  - The single-sided recursion has expected stack depth O(log n); rewriting it as a loop makes the auxiliary space strictly O(1)
Comparison with Other Solutions:
| Solution | Time (Avg) | Time (Worst) | Space | When to Use |
|---|---|---|---|---|
| QuickSelect | O(n) | O(n²) | O(1) | Best overall - optimal time & space |
| Max Heap | O(n log k) | O(n log k) | O(k) | Small k, or streaming input |
| Full Sort | O(n log n) | O(n log n) | O(n) | Code simplicity, k ≈ n |
When to Use QuickSelect: - Optimal performance needed (technical interviews, production code) - Large datasets where O(1) space matters - k varies widely (works well for all k values)
9.5.7.3 Implementation Steps
Helper Function: partition(points, left, right)
1. Choose random pivot index in [left, right]
2. Swap pivot to end position
3. Use two-pointer technique:
- i = write position for elements \(\le\) pivot
- j = read position scanning all elements
4. For each element from left to right-1:
- If distance \(\le\) pivot distance: swap to position i, increment i
5. Swap pivot back to position i (its final position)
6. Return i (pivot’s final index)
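These six steps can be exercised in isolation on a plain array of squared distances. The sketch below uses the same Lomuto-style partition as the full solution, but with a fixed pivot (the last element) instead of a random one so the result is reproducible; `PartitionSketch` is a name invented for the demo:

```java
import java.util.Arrays;

public class PartitionSketch {
    // Lomuto partition: last element is the pivot; returns its final index.
    static int partition(int[] dist, int left, int right) {
        int pivot = dist[right];
        int i = left; // write position for elements <= pivot
        for (int j = left; j < right; j++) {
            if (dist[j] <= pivot) {
                int tmp = dist[i]; dist[i] = dist[j]; dist[j] = tmp;
                i++;
            }
        }
        // Move pivot to its final sorted position
        int tmp = dist[i]; dist[i] = dist[right]; dist[right] = tmp;
        return i;
    }

    public static void main(String[] args) {
        int[] dist = {10, 8, 50, 8}; // squared distances; pivot = 8
        int p = partition(dist, 0, dist.length - 1);
        System.out.println(p);                     // 1
        System.out.println(Arrays.toString(dist)); // [8, 8, 50, 10]
        // Everything at or left of p is <= dist[p]; everything right of p is > dist[p]
    }
}
```

After the call, the pivot sits at its final sorted position even though neither side is sorted, which is exactly the invariant QuickSelect relies on.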
Main Function: quickSelect(points, left, right, k)
7. If left \(\ge\) right: return (base case)
8. Partition array and get pivot index
9. If pivot index \(=\) k-1: Done! First k elements are answer
10. If pivot index \(<\) k-1: Recurse on right: quickSelect(points, pivotIndex+1, right, k)
11. If pivot index \(>\) k-1: Recurse on left: quickSelect(points, left, pivotIndex-1, k)
Build Result:
12. Extract the first k points from the partially partitioned array
9.5.7.4 Code - Java
import java.util.Arrays;
import java.util.Random;

class Solution {
    private Random rand = new Random();

    public int[][] kClosest(int[][] points, int k) {
        quickSelect(points, 0, points.length - 1, k);
        // First k elements are now the k closest (not sorted, but that's okay)
        return Arrays.copyOfRange(points, 0, k);
    }

    private void quickSelect(int[][] points, int left, int right, int k) {
        if (left >= right) {
            return;
        }
        // Partition and get pivot's final position
        int pivotIndex = partition(points, left, right);
        // Check if we've found exactly k elements on the left
        if (pivotIndex == k - 1) {
            return; // Perfect! First k elements are the answer
        } else if (pivotIndex < k - 1) {
            // Need more elements, recurse right
            quickSelect(points, pivotIndex + 1, right, k);
        } else {
            // Have too many elements, recurse left
            quickSelect(points, left, pivotIndex - 1, k);
        }
    }

    private int partition(int[][] points, int left, int right) {
        // Randomized pivot to avoid worst-case O(n²)
        int randomPivot = left + rand.nextInt(right - left + 1);
        swap(points, randomPivot, right); // Move pivot to end
        int pivotDist = distance(points[right]);
        int i = left; // Write position for elements <= pivot
        // Partition: move all closer points to the left
        for (int j = left; j < right; j++) {
            if (distance(points[j]) <= pivotDist) {
                swap(points, i, j);
                i++;
            }
        }
        // Move pivot to its final position
        swap(points, i, right);
        return i;
    }

    private int distance(int[] point) {
        // Squared Euclidean distance; sqrt is unnecessary for comparisons
        return point[0] * point[0] + point[1] * point[1];
    }

    private void swap(int[][] points, int i, int j) {
        int[] temp = points[i];
        points[i] = points[j];
        points[j] = temp;
    }
}
9.5.7.5 Code - Golang
package main

import "math/rand"

func kClosest(points [][]int, k int) [][]int {
	quickSelect(points, 0, len(points)-1, k)
	// First k elements are now the k closest (not sorted, but that's okay)
	return points[:k]
}

func quickSelect(points [][]int, left, right, k int) {
	if left >= right {
		return
	}
	// Partition and get pivot's final position
	pivotIndex := partition(points, left, right)
	// Check if we've found exactly k elements on the left
	if pivotIndex == k-1 {
		return // Perfect! First k elements are the answer
	} else if pivotIndex < k-1 {
		// Need more elements, recurse right
		quickSelect(points, pivotIndex+1, right, k)
	} else {
		// Have too many elements, recurse left
		quickSelect(points, left, pivotIndex-1, k)
	}
}

func partition(points [][]int, left, right int) int {
	// Randomized pivot to avoid worst-case O(n²)
	randomPivot := left + rand.Intn(right-left+1)
	points[randomPivot], points[right] = points[right], points[randomPivot]
	pivotDist := distance(points[right])
	i := left // Write position for elements <= pivot
	// Partition: move all closer points to the left
	for j := left; j < right; j++ {
		if distance(points[j]) <= pivotDist {
			points[i], points[j] = points[j], points[i]
			i++
		}
	}
	// Move pivot to its final position
	points[i], points[right] = points[right], points[i]
	return i
}

func distance(point []int) int {
	// Squared Euclidean distance; sqrt is unnecessary for comparisons
	return point[0]*point[0] + point[1]*point[1]
}