14.24 Count Square Submatrices with All Ones

14.24.1 Problem Metadata

14.24.2 Description

Given an m x n matrix of ones and zeros, return the number of square submatrices that contain only ones.

14.24.3 Examples

Example 1:

Input: matrix =
[
  [0,1,1,1],
  [1,1,1,1],
  [0,1,1,1]
]
Output: 15
Explanation:
There are 10 squares of side 1.
There are 4 squares of side 2.
There is 1 square of side 3.
Total number of squares = 10 + 4 + 1 = 15.

Example 2:

Input: matrix =
[
  [1,0,1],
  [1,1,0],
  [1,1,0]
]
Output: 7
Explanation:
There are 6 squares of side 1.
There is 1 square of side 2.
Total number of squares = 6 + 1 = 7.

14.24.4 Constraints

  • 1 <= matrix.length <= 300
  • 1 <= matrix[0].length <= 300
  • 0 <= matrix[i][j] <= 1

14.24.5 Solution - Dynamic Programming (In-Place)

14.24.5.1 Walkthrough

This problem is closely related to Maximal Square (LeetCode 221), but instead of finding the largest square, we need to count all squares. The key insight is:

For each cell (r, c) containing 1, dp[r][c] represents the side length of the largest square with (r, c) as the bottom-right corner. This value also tells us HOW MANY squares end at this position.

For example, if dp[r][c] = 3, there are 3 squares ending at (r, c):

  • One 1×1 square
  • One 2×2 square
  • One 3×3 square

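The DP recurrence is the same as in Maximal Square. For a cell containing 1 that is not in the first row or first column:

dp[r][c] = min(dp[r-1][c], dp[r][c-1], dp[r-1][c-1]) + 1

A cell containing 0 has dp[r][c] = 0, and a cell containing 1 in the first row or first column has dp[r][c] = 1, since no larger square can end there.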

The total count is the sum of all dp[r][c] values.
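
The recurrence and the final summation can be sketched with a separate DP table, which leaves the input untouched. This is a minimal illustration, not the in-place solution of this section; the class name SeparateTableCounter is a hypothetical name chosen for the example.

```java
class SeparateTableCounter {
    // Counts all-ones square submatrices using a separate dp table.
    // dp[r][c] = side of the largest all-ones square whose bottom-right corner is (r, c).
    static int countSquares(int[][] matrix) {
        int rows = matrix.length;
        int cols = matrix[0].length;
        int[][] dp = new int[rows][cols]; // Java zero-initializes, so 0-cells need no work
        int total = 0;

        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                if (matrix[r][c] == 0) {
                    continue; // no square can end on a 0
                }
                if (r == 0 || c == 0) {
                    dp[r][c] = 1; // first row/column: only the 1x1 square fits
                } else {
                    dp[r][c] = Math.min(dp[r - 1][c],
                               Math.min(dp[r][c - 1], dp[r - 1][c - 1])) + 1;
                }
                total += dp[r][c]; // dp value == number of squares ending here
            }
        }
        return total;
    }

    public static void main(String[] args) {
        int[][] example1 = {{0, 1, 1, 1}, {1, 1, 1, 1}, {0, 1, 1, 1}};
        System.out.println(countSquares(example1)); // prints 15
    }
}
```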

14.24.5.2 Visual Example

For the input matrix:

[0, 1, 1, 1]
[1, 1, 1, 1]
[0, 1, 1, 1]

Processing order (row by row, left to right):

Step 1: (0,0)=0, skip
Step 2: (0,1)=1, result=1, matrix: [0, 1, 1, 1]
                                     [1, 1, 1, 1]
                                     [0, 1, 1, 1]

Step 3: (0,2)=1, result=2, matrix: [0, 1, 1, 1]
                                     [1, 1, 1, 1]
                                     [0, 1, 1, 1]

Step 4: (0,3)=1, result=3, matrix: [0, 1, 1, 1]
                                     [1, 1, 1, 1]
                                     [0, 1, 1, 1]

Step 5: (1,0)=1, result=4, matrix: [0, 1, 1, 1]
                                     [1, 1, 1, 1]
                                     [0, 1, 1, 1]

Step 6: (1,1)=1, result=5, matrix: [0, 1, 1, 1]
                    min(1,1,0)+1=1  [1, 1, 1, 1]
                    no change       [0, 1, 1, 1]

Step 7: (1,2)=1, result=6, matrix: [0, 1, 1, 1]
                    min(1,1,1)+1=2  [1, 1, 2, 1]
                    result += 1 = 7 [0, 1, 1, 1]

Step 8: (1,3)=1, result=8, matrix: [0, 1, 1, 1]
                    min(1,2,1)+1=2  [1, 1, 2, 2]
                    result += 1 = 9 [0, 1, 1, 1]

Step 9: (2,0)=0, skip

Step 10: (2,1)=1, result=10, matrix: [0, 1, 1, 1]
                     min(1,0,1)+1=1   [1, 1, 2, 2]
                     no change        [0, 1, 1, 1]

Step 11: (2,2)=1, result=11, matrix: [0, 1, 1, 1]
                     min(2,1,1)+1=2   [1, 1, 2, 2]
                     result += 1 = 12 [0, 1, 2, 1]

Step 12: (2,3)=1, result=13, matrix: [0, 1, 1, 1]
                     min(2,2,2)+1=3   [1, 1, 2, 2]
                     result += 2 = 15 [0, 1, 2, 3]

Final result: 15
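
The trace above can be reproduced with a short program that runs the in-place pass on Example 1 and prints the resulting matrix. This is a sketch for checking the walkthrough by hand; TraceCheck and fillInPlace are hypothetical names used only here.

```java
import java.util.Arrays;

class TraceCheck {
    // Overwrites each 1-cell outside the first row/column with its DP value
    // and returns the sum of all DP values (the number of all-ones squares).
    static int fillInPlace(int[][] matrix) {
        int result = 0;
        for (int r = 0; r < matrix.length; r++) {
            for (int c = 0; c < matrix[0].length; c++) {
                if (matrix[r][c] == 1 && r > 0 && c > 0) {
                    int min = Math.min(matrix[r - 1][c], matrix[r][c - 1]);
                    matrix[r][c] = Math.min(min, matrix[r - 1][c - 1]) + 1;
                }
                result += matrix[r][c]; // adds 0 for zero cells
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[][] m = {{0, 1, 1, 1}, {1, 1, 1, 1}, {0, 1, 1, 1}};
        int total = fillInPlace(m);
        for (int[] row : m) {
            System.out.println(Arrays.toString(row));
        }
        System.out.println("total = " + total);
        // Output:
        // [0, 1, 1, 1]
        // [1, 1, 2, 2]
        // [0, 1, 2, 3]
        // total = 15
    }
}
```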

14.24.5.3 Key Insight

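The elegant part of this solution is that the DP value at each cell is also the count of squares ending at that position. If the largest all-ones square with bottom-right corner (r, c) has side k, then the squares of side 1, 2, ..., k with the same corner all lie inside it, so exactly k squares end at (r, c). Every square has exactly one bottom-right corner, so summing the DP values counts each square exactly once.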
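
One way to convince yourself of this is a brute-force cross-check that counts, for a single cell, the all-ones squares ending there by direct inspection. The sketch below is intended only for small inputs; CornerCounter and its method names are hypothetical.

```java
class CornerCounter {
    // Number of all-ones squares whose bottom-right corner is (r, c),
    // found by testing side lengths 1, 2, 3, ... directly.
    static int squaresEndingAt(int[][] matrix, int r, int c) {
        int count = 0;
        for (int side = 1; side <= r + 1 && side <= c + 1; side++) {
            if (!allOnes(matrix, r - side + 1, c - side + 1, side)) {
                break; // any larger square with this corner contains the same zero
            }
            count++;
        }
        return count;
    }

    // True if the side x side block with top-left corner (top, left) holds only ones.
    static boolean allOnes(int[][] matrix, int top, int left, int side) {
        for (int i = top; i < top + side; i++) {
            for (int j = left; j < left + side; j++) {
                if (matrix[i][j] != 1) {
                    return false;
                }
            }
        }
        return true;
    }

    // Sums the per-corner counts over the whole matrix.
    static int countSquares(int[][] matrix) {
        int total = 0;
        for (int r = 0; r < matrix.length; r++) {
            for (int c = 0; c < matrix[0].length; c++) {
                total += squaresEndingAt(matrix, r, c);
            }
        }
        return total;
    }
}
```

On Example 1, squaresEndingAt returns 3 for cell (2, 3) and 2 for cell (1, 2), matching the DP values at those cells, and countSquares returns 15.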

14.24.5.4 Analysis

  • Time Complexity: O(m × n) - Single pass through the matrix with constant work per cell
  • Space Complexity: O(1) extra - The DP values overwrite the input matrix (O(m × n) if a separate DP table is used to keep the input intact)
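
If the input must not be modified, a middle ground between the two space figures above is to keep only the previous and current DP rows, which needs O(n) extra space. The recurrence only ever reads the row above and the current row, so nothing else has to be stored. This variant is not part of the solution in this section; it is a sketch, and RollingRowCounter is a hypothetical name.

```java
class RollingRowCounter {
    // Same recurrence as the in-place version, but the DP values live in two
    // row buffers, so the input matrix is left unchanged.
    static int countSquares(int[][] matrix) {
        int cols = matrix[0].length;
        int[] prev = new int[cols]; // DP values of row r - 1
        int[] curr = new int[cols]; // DP values of row r, filled left to right
        int total = 0;

        for (int r = 0; r < matrix.length; r++) {
            for (int c = 0; c < cols; c++) {
                if (matrix[r][c] == 0) {
                    curr[c] = 0;
                } else if (r == 0 || c == 0) {
                    curr[c] = 1;
                } else {
                    // top = prev[c], left = curr[c - 1], diagonal = prev[c - 1]
                    curr[c] = Math.min(prev[c], Math.min(curr[c - 1], prev[c - 1])) + 1;
                }
                total += curr[c];
            }
            // Swap buffers; every entry of curr is overwritten in the next row.
            int[] tmp = prev;
            prev = curr;
            curr = tmp;
        }
        return total;
    }
}
```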

14.24.5.5 Implementation Steps

  1. Initialize result = 0 to track total count
  2. For each cell (r, c):
    • If matrix[r][c] == 1:
      • Add 1 to result (every 1 forms at least a 1×1 square)
      • If not in first row or column:
        • Calculate matrix[r][c] = min(top, left, diagonal) + 1
        • Add matrix[r][c] - 1 to result (additional squares of size 2×2, 3×3, etc.)
  3. Return result

14.24.5.6 Code - Java

class Solution {
    public int countSquares(int[][] matrix) {
        int rows = matrix.length;
        int cols = matrix[0].length;
        int result = 0;

        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                // Only process cells containing 1
                if (matrix[r][c] == 1) {
                    // Each 1 can form at least a 1x1 square
                    result++;

                    // For cells not in first row/column, apply DP transition
                    if (r != 0 && c != 0) {
                        // DP transition: current square size is limited by
                        // the minimum of the three neighbors + 1
                        int min = Math.min(matrix[r - 1][c], matrix[r][c - 1]);
                        min = Math.min(min, matrix[r - 1][c - 1]);
                        matrix[r][c] = min + 1;

                        // Add additional squares (2x2, 3x3, ..., NxN)
                        // If matrix[r][c] = 3, we already counted the 1x1,
                        // so add 2 more for the 2x2 and 3x3 squares
                        result += matrix[r][c] - 1;
                    }
                }
            }
        }

        return result;
    }
}