5.17 Kth Smallest Element in a BST

5.17.1 Problem Metadata

5.17.2 Description

Given the root of a binary search tree, and an integer k, return the kth smallest value (1-indexed) of all the values of the nodes in the tree.

5.17.3 Examples

Example 1:

Input: root = [3,1,4,null,2], k = 1
Tree:
    3
   / \
  1   4
   \
    2
Output: 1
Explanation: The in-order traversal is [1, 2, 3, 4], and the 1st smallest is 1

Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3
Tree:
      5
     / \
    3   6
   / \
  2   4
 /
1
Output: 3
Explanation: The in-order traversal is [1, 2, 3, 4, 5, 6], and the 3rd smallest is 3
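The Input lines above use LeetCode’s level-order encoding, in which null marks a missing child. As an illustration only, here is a minimal Java sketch that decodes that format and reproduces the in-order sequences from both explanations. The names TreeCodec, fromLevelOrder, and inorder are hypothetical, and TreeNode is assumed to have the usual LeetCode shape (val, left, right).

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;

// Minimal stand-in for LeetCode's TreeNode (assumed shape: val, left, right).
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }
}

// Hypothetical helper class for decoding the examples' input format.
class TreeCodec {
    // Build a tree from level-order values; null marks a missing child.
    static TreeNode fromLevelOrder(Integer[] values) {
        if (values.length == 0 || values[0] == null) return null;
        TreeNode root = new TreeNode(values[0]);
        Queue<TreeNode> parents = new ArrayDeque<>(); // nodes still waiting for children
        parents.add(root);
        int i = 1;
        while (i < values.length) {
            TreeNode parent = parents.poll();
            // Left child slot
            if (values[i] != null) {
                parent.left = new TreeNode(values[i]);
                parents.add(parent.left);
            }
            i++;
            // Right child slot (may lie past the end of the array)
            if (i < values.length && values[i] != null) {
                parent.right = new TreeNode(values[i]);
                parents.add(parent.right);
            }
            i++;
        }
        return root;
    }

    // In-order traversal: left subtree, node, right subtree.
    static List<Integer> inorder(TreeNode root) {
        List<Integer> out = new ArrayList<>();
        collect(root, out);
        return out;
    }

    private static void collect(TreeNode node, List<Integer> out) {
        if (node == null) return;
        collect(node.left, out);
        out.add(node.val);
        collect(node.right, out);
    }
}
```

Taking element k - 1 of the in-order list (index 2 for Example 2) gives the expected output 3.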

5.17.4 Constraints

  • The number of nodes in the tree is n
  • \(1 \le k \le n \le 10^4\)
  • \(0 \le \text{Node.val} \le 10^4\)

Follow up: If the BST is modified often (i.e., we can do insert and delete operations) and you need to find the kth smallest frequently, how would you optimize?
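This section does not work through the follow-up. One common answer, sketched here under stated assumptions, is to store in every node the size of its subtree: the kth smallest is then found in O(H) by comparing k with the left subtree’s size, and insert and delete keep the sizes current by recomputing them along the search path. SizedNode and OrderStatisticBst are hypothetical names, and the tree is a plain unbalanced BST, so H can still grow to n; a self-balancing variant that maintains the same size field through rotations would be needed to guarantee O(log n).

```java
// Hypothetical node type for this sketch: a BST node that also records
// the number of nodes in the subtree rooted at it (itself included).
class SizedNode {
    int val;
    int size = 1;
    SizedNode left, right;

    SizedNode(int val) { this.val = val; }
}

// Hypothetical order-statistic BST (unbalanced, so every operation costs O(H)).
class OrderStatisticBst {
    private SizedNode root;

    private static int size(SizedNode node) {
        return node == null ? 0 : node.size;
    }

    public void insert(int val) {
        root = insert(root, val);
    }

    // Standard BST insert (equal values go right); sizes are recomputed on the way back up.
    private static SizedNode insert(SizedNode node, int val) {
        if (node == null) return new SizedNode(val);
        if (val < node.val) node.left = insert(node.left, val);
        else node.right = insert(node.right, val);
        node.size = 1 + size(node.left) + size(node.right);
        return node;
    }

    public void delete(int val) {
        root = delete(root, val);
    }

    // Standard BST delete; every node on the search path gets its size recomputed.
    private static SizedNode delete(SizedNode node, int val) {
        if (node == null) return null; // value not present: nothing changes
        if (val < node.val) {
            node.left = delete(node.left, val);
        } else if (val > node.val) {
            node.right = delete(node.right, val);
        } else {
            if (node.left == null) return node.right;
            if (node.right == null) return node.left;
            // Two children: copy the in-order successor's value here, then delete the successor.
            SizedNode succ = node.right;
            while (succ.left != null) succ = succ.left;
            node.val = succ.val;
            node.right = delete(node.right, succ.val);
        }
        node.size = 1 + size(node.left) + size(node.right);
        return node;
    }

    // kth smallest (1-indexed): one root-to-node walk guided by left-subtree sizes.
    public int kthSmallest(int k) {
        SizedNode node = root;
        while (node != null) {
            int leftSize = size(node.left);
            if (k <= leftSize) {
                node = node.left;          // answer lies in the left subtree
            } else if (k == leftSize + 1) {
                return node.val;           // exactly leftSize values come before this node
            } else {
                k -= leftSize + 1;         // skip the left subtree and this node
                node = node.right;
            }
        }
        throw new IllegalArgumentException("k is out of range");
    }
}
```

Because each query only walks one root-to-node path, frequent kth-smallest queries no longer pay O(k) each, at the cost of one extra int per node.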

5.17.5 Solution - In-Order Traversal (Recursive)

5.17.5.1 Walkthrough

The key insight is that in-order traversal of a BST visits nodes in sorted ascending order (left → root → right). Therefore, the kth smallest element is simply the kth node visited during in-order traversal.
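The insight can be checked directly with a brute-force version: collect the full in-order sequence into a list, which comes out sorted, and index it at k - 1. This sketch (the class name KthSmallestNaive is hypothetical; TreeNode is assumed to follow the LeetCode shape) always costs O(n) time and O(n) extra space regardless of k, which is the cost the counter-based solutions in this section avoid.

```java
import java.util.ArrayList;
import java.util.List;

// Minimal stand-in for LeetCode's TreeNode (assumed shape: val, left, right).
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class KthSmallestNaive {
    // Full in-order traversal into a list (sorted for a BST), then index it.
    static int kthSmallest(TreeNode root, int k) {
        List<Integer> sorted = new ArrayList<>();
        inorder(root, sorted);
        return sorted.get(k - 1); // k is 1-indexed
    }

    private static void inorder(TreeNode node, List<Integer> out) {
        if (node == null) return;
        inorder(node.left, out);  // smaller values first
        out.add(node.val);
        inorder(node.right, out); // then larger values
    }
}
```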

Algorithm:

  1. Perform in-order DFS: Traverse left subtree first, then process current node, then traverse right subtree
  2. Count visited nodes: Maintain a counter that increments each time we visit a node
  3. Early termination: When the counter equals k, we’ve found the kth smallest element; store it and stop further traversal
  4. Use class-level state: Store count and result as instance variables to track state across recursive calls

Why In-Order Traversal Works:

For a BST with the property that all left subtree values < root < all right subtree values, in-order traversal guarantees we visit nodes in ascending order:

Tree:    5
        / \
       3   6
      / \
     2   4
    /
   1

In-order visits: 1 → 2 → 3 → 4 → 5 → 6 (sorted!)
For k=3: Stop at the 3rd visit, which is node 3

Optimization: Early Termination

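We don’t need to visit all n nodes: we can stop as soon as we find the kth element. Each call checks result != -1 on entry and returns immediately once the answer is known, so no further subtrees are explored. Using -1 as the "not found" sentinel is safe because every Node.val is at least 0.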
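To make the saving concrete, the following sketch copies the recursive solution and exposes its counter through a hypothetical processedCount() method (EarlyStopDemo is likewise a hypothetical name, and TreeNode is assumed to follow the LeetCode shape). It also adds one check right after the left recursion, so ancestors do not count themselves once the answer is known. On the tree from Example 2 with k = 3, only 3 of the 6 nodes are counted.

```java
// Minimal stand-in for LeetCode's TreeNode (assumed shape: val, left, right).
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Copy of the recursive solution with its counter exposed.
class EarlyStopDemo {
    private int count = 0;   // nodes processed in in-order position
    private int result = -1; // -1 = not found yet (safe because Node.val >= 0)

    int kthSmallest(TreeNode root, int k) {
        inorder(root, k);
        return result;
    }

    int processedCount() {
        return count;
    }

    private void inorder(TreeNode node, int k) {
        if (node == null || result != -1) return;
        inorder(node.left, k);
        if (result != -1) return; // answer found in the left subtree: do not count this node
        count++;
        if (count == k) {
            result = node.val;
            return;
        }
        inorder(node.right, k);
    }
}
```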

5.17.5.2 Analysis

  • Time Complexity: O(H + k) where H is tree height
    • We traverse down to the leftmost node: O(H)
    • Then visit k nodes in in-order: O(k)
    • Total: O(H + k)
    • Balanced tree: O(log n + k), since H = O(log n)
    • Worst case: O(n), when k = n or when the tree is skewed to the left (then H = n, and even k = 1 walks down every node)
  • Space Complexity: O(H) for recursion stack
    • O(log n) for balanced tree
    • O(n) for skewed tree
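The O(H) term matters in practice. This sketch instruments a copy of the recursive helper to count how many non-null nodes get past the base case (CostProbe, leftChain, and rightChain are hypothetical names; TreeNode is assumed to follow the LeetCode shape). With 1,000 nodes and k = 1, a left-leaning chain enters all 1,000 nodes because the smallest value sits at the bottom, while a right-leaning chain enters just 1, since the root is already the smallest.

```java
// Minimal stand-in for LeetCode's TreeNode (assumed shape: val, left, right).
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }
}

class CostProbe {
    private int count = 0;
    private int result = -1;
    private int entered = 0; // non-null nodes whose call got past the base case

    // Returns {kth smallest value, number of nodes entered}.
    static int[] run(TreeNode root, int k) {
        CostProbe probe = new CostProbe();
        probe.inorder(root, k);
        return new int[]{probe.result, probe.entered};
    }

    // Same logic as the recursive solution, plus the entered counter.
    private void inorder(TreeNode node, int k) {
        if (node == null || result != -1) return;
        entered++;
        inorder(node.left, k);
        count++;
        if (count == k) {
            result = node.val;
            return;
        }
        inorder(node.right, k);
    }

    // Chain n -> n-1 -> ... -> 1 where each node's only child is its left child.
    static TreeNode leftChain(int n) {
        TreeNode root = null;
        for (int v = 1; v <= n; v++) {
            TreeNode node = new TreeNode(v);
            node.left = root;
            root = node;
        }
        return root;
    }

    // Chain 1 -> 2 -> ... -> n where each node's only child is its right child.
    static TreeNode rightChain(int n) {
        TreeNode root = null;
        for (int v = n; v >= 1; v--) {
            TreeNode node = new TreeNode(v);
            node.right = root;
            root = node;
        }
        return root;
    }
}
```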

5.17.5.3 Implementation Steps

  1. Initialize class-level variables: count = 0 and result = -1
  2. Call in-order traversal helper starting from root
  3. In the helper function:
    • Base case: if node is null or result already found, return
    • Recurse left subtree (visit smaller values first)
    • Increment count and check if count equals k
    • If count equals k, store node value in result and return
    • Recurse right subtree (visit larger values)
  4. Return the result

5.17.5.4 Code - Java

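// TreeNode: standard LeetCode definition (int val; TreeNode left, right)
class Solution {
    private int count = 0;    // nodes visited so far, in in-order position
    private int result = -1;  // -1 means "not found yet"; safe because Node.val >= 0

    public int kthSmallest(TreeNode root, int k) {
        // Reset state so the same Solution instance can be reused across calls
        count = 0;
        result = -1;
        inorder(root, k);
        return result;
    }

    private void inorder(TreeNode node, int k) {
        if (node == null || result != -1) {
            return;
        }

        // Left: visit all smaller values
        inorder(node.left, k);
        if (result != -1) {
            return; // already found in the left subtree; skip the rest
        }

        // Current: process this node
        count++;
        if (count == k) {
            result = node.val;
            return;
        }

        // Right: visit all larger values
        inorder(node.right, k);
    }
}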
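The solution above keeps count and result in instance fields. An alternative, sketched here with the hypothetical name StatelessSolution (TreeNode again assumed to follow the LeetCode shape), keeps all state inside the call: a one-element array serves as a mutable counter, and the helper returns the found value (or null) instead of writing to a field, so one instance can be reused without any reset.

```java
// Minimal stand-in for LeetCode's TreeNode (assumed shape: val, left, right).
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class StatelessSolution {
    public int kthSmallest(TreeNode root, int k) {
        int[] remaining = {k}; // one-element array: a counter shared by all recursive calls
        Integer found = search(root, remaining);
        return found == null ? -1 : found;
    }

    // Returns the kth value if it lies in this subtree, otherwise null.
    private Integer search(TreeNode node, int[] remaining) {
        if (node == null) return null;
        Integer left = search(node.left, remaining);
        if (left != null) return left;          // already found in the left subtree
        remaining[0]--;
        if (remaining[0] == 0) return node.val; // this node is the kth in order
        return search(node.right, remaining);
    }
}
```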

5.17.6 Solution - Iterative with Stack

5.17.6.1 Walkthrough

This approach simulates the in-order traversal with an explicit stack instead of recursion. It avoids the overhead and depth limit of the call stack and makes each step of the traversal visible.

Algorithm:

  1. Push all left children: Starting from root, push nodes while moving left until we reach null
  2. Pop and process: Pop a node from the stack (this is the next smallest value)
  3. Decrement k: Each popped node represents one element in sorted order
  4. Check if done: When k equals 0, we’ve found the answer
  5. Move to right subtree: After processing a node, move to its right child and repeat

Why This Works:

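The stack ensures we always visit the leftmost (smallest) unvisited node next. After visiting a node, we explore its right subtree, which holds exactly the values between this node and the node now on top of the stack (its nearest larger ancestor), so no value is skipped or visited out of order.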

Tree:    3
        / \
       1   4
        \
         2

Iteration steps for k = 2:
1. Push 3, 1 (going left)
2. Pop 1 (1st smallest), k=1, continue
3. Push 2 (1.right)
4. Pop 2 (2nd smallest), k=0 → Return 2
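The same trace can be reproduced in code. This sketch (IterativeTrace and poppedUntilKth are hypothetical names; TreeNode is assumed to follow the LeetCode shape) runs the stack loop from this section and records every popped value until the kth, so for the tree above with k = 2 it returns [1, 2].

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Minimal stand-in for LeetCode's TreeNode (assumed shape: val, left, right).
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class IterativeTrace {
    // Runs the stack-based in-order loop and records each popped value up to the kth.
    static List<Integer> poppedUntilKth(TreeNode root, int k) {
        List<Integer> popped = new ArrayList<>();
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode current = root;
        while (current != null || !stack.isEmpty()) {
            while (current != null) {   // go as far left as possible
                stack.push(current);
                current = current.left;
            }
            current = stack.pop();      // next smallest unvisited node
            popped.add(current.val);
            if (--k == 0) break;        // the last recorded value is the answer
            current = current.right;    // continue with the right subtree
        }
        return popped;
    }
}
```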

5.17.6.2 Analysis

  • Time Complexity: O(H + k)
    • Same as recursive: traverse to leftmost node (H), then visit k nodes
  • Space Complexity: O(H) for the explicit stack
    • Same bound as the recursive version, but the stack is an ArrayDeque on the heap, so a deeply skewed tree cannot overflow the call stack
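To see the O(H) space bound, this sketch adds a high-water mark to the iterative loop (StackDepthProbe and its chain helpers are hypothetical names; TreeNode is assumed to follow the LeetCode shape). On a 10,000-node left-leaning chain, the maximum size allowed by the constraints, the stack reaches 10,000 entries, while a right-leaning chain never holds more than 1.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Minimal stand-in for LeetCode's TreeNode (assumed shape: val, left, right).
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }
}

class StackDepthProbe {
    // Same loop as the iterative solution; returns {kth value, largest stack size seen}.
    static int[] kthSmallestWithMaxStack(TreeNode root, int k) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode current = root;
        int maxSize = 0;
        while (current != null || !stack.isEmpty()) {
            while (current != null) {
                stack.push(current);
                maxSize = Math.max(maxSize, stack.size());
                current = current.left;
            }
            current = stack.pop();
            if (--k == 0) return new int[]{current.val, maxSize};
            current = current.right;
        }
        return new int[]{-1, maxSize}; // k larger than the number of nodes
    }

    // Chain n -> n-1 -> ... -> 1 using only left children (height n).
    static TreeNode leftChain(int n) {
        TreeNode root = null;
        for (int v = 1; v <= n; v++) {
            TreeNode node = new TreeNode(v);
            node.left = root;
            root = node;
        }
        return root;
    }

    // Chain 1 -> 2 -> ... -> n using only right children (also height n).
    static TreeNode rightChain(int n) {
        TreeNode root = null;
        for (int v = n; v >= 1; v--) {
            TreeNode node = new TreeNode(v);
            node.right = root;
            root = node;
        }
        return root;
    }
}
```

Both chains have height n, yet only the left chain fills the stack, because the stack holds the current path of pending left-ancestors rather than the whole tree.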

5.17.6.3 Code - Java

import java.util.ArrayDeque;
import java.util.Deque;

class Solution {
    public int kthSmallest(TreeNode root, int k) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode current = root;

        while (current != null || !stack.isEmpty()) {
            // Push all left children
            while (current != null) {
                stack.push(current);
                current = current.left;
            }

            // Pop the next smallest node
            current = stack.pop();
            k--;

            // Found the kth smallest
            if (k == 0) {
                return current.val;
            }

            // Move to right subtree
            current = current.right;
        }

        return -1; // Should never reach here given constraints
    }
}

Implementation Notes:

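  • Cleaner loop structure: The outer while loop runs until the kth node is popped, and the inner loop performs the descent to the leftmost unvisited node
  • No class-level state: All state is local to the method
  • Safer on deep trees: The stack lives on the heap, so a skewed tree cannot overflow the call stack, and the traversal state (stack and current) is easy to inspect in a debugger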