Kth Smallest Element in a Sorted Matrix


Problem Statement

You are given an n x n matrix in which every row and every column is sorted in non-decreasing (ascending) order. Find the kth smallest element in the matrix.

Note: The kth smallest element refers to the kth element in the sorted sequence, not necessarily a distinct one.


Examples

Example 1:

Input:
matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Output: 13
Explanation: The elements in order: [1, 5, 9, 10, 11, 12, 13, 13, 15]. The 8th smallest is 13.

Example 2:

Input:
matrix = [[-5]], k = 1
Output: -5


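Constraints

  • The matrix is n x n with n >= 1.
  • Every row and every column is sorted in non-decreasing order, so duplicates are allowed.
  • 1 <= k <= n².
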

Intuition

We are given a row-wise and column-wise sorted matrix, and we are asked to find the kth smallest element.

Naively, we could flatten the matrix into one array and sort it, but that needs O(n²) extra space and O(n² log n²) = O(n² log n) time.
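For reference, the naive flatten-and-sort baseline can be sketched as follows. The class name NaiveKthSmallest is hypothetical, chosen only for illustration; the method is handy as a correctness oracle when testing faster versions.

```java
import java.util.Arrays;

// Baseline sketch: copy every element into one array, sort it, and index into it.
// Uses O(n^2) extra space and O(n^2 log n) time.
class NaiveKthSmallest {
    static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        int[] flat = new int[n * n];
        int idx = 0;
        for (int[] row : matrix) {
            for (int value : row) {
                flat[idx++] = value;
            }
        }
        Arrays.sort(flat);
        // k is 1-based, array indices are 0-based.
        return flat[k - 1];
    }

    public static void main(String[] args) {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println(kthSmallest(matrix, 8)); // prints 13
    }
}
```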

Because the matrix is sorted, we can instead binary search over the value space (not the index space). The smallest value is matrix[0][0] and the largest is matrix[n-1][n-1], so those two corners bound the search range.

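For a candidate value mid, let count(mid) be the number of elements <= mid. count(mid) never decreases as mid grows, and the kth smallest element is the smallest value x with count(x) >= k. So if count(mid) < k, the answer is greater than mid; otherwise it is at most mid.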
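The "smallest value satisfying a monotone condition" idea can be written as a generic helper. In this sketch the names ValueSpaceSearch, smallestSatisfying and countAtMost are hypothetical; the predicate uses a brute-force O(n²) count purely to keep the example short. It assumes the condition is false for small values and true from some point on, that it holds at hi, and that hi - lo fits in an int.

```java
import java.util.function.IntPredicate;

class ValueSpaceSearch {
    // Returns the smallest x in [lo, hi] for which ok(x) is true.
    // Assumes ok is monotone (false ... false true ... true) and ok(hi) is true.
    static int smallestSatisfying(int lo, int hi, IntPredicate ok) {
        while (lo < hi) {
            int mid = lo + (hi - lo) / 2;
            if (ok.test(mid)) {
                hi = mid;      // mid works, so the answer is mid or smaller
            } else {
                lo = mid + 1;  // mid fails, so the answer is strictly larger
            }
        }
        return lo;
    }

    // Brute-force count of entries <= x, used only to define the predicate.
    static int countAtMost(int[][] matrix, int x) {
        int count = 0;
        for (int[] row : matrix) {
            for (int v : row) {
                if (v <= x) {
                    count++;
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        int n = matrix.length;
        int k = 8;
        int answer = smallestSatisfying(matrix[0][0], matrix[n - 1][n - 1],
                x -> countAtMost(matrix, x) >= k);
        System.out.println(answer); // prints 13
    }
}
```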


Approach

  1. Let low = matrix[0][0], high = matrix[n-1][n-1] (minimum and maximum values).

  2. While low < high:

    • Compute mid = low + (high - low) / 2.

    • Count the number of elements <= mid using a helper function.

    • If count is less than k, move low to mid + 1.

    • Else, move high to mid.

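  3. When the loop ends, low == high is the smallest value with count >= k. That value must occur in the matrix (otherwise count(low - 1) would equal count(low), contradicting minimality), so low is the kth smallest element.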

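Counting Strategy (Helper Function):

Start at the bottom-left corner (row = n - 1, col = 0). If matrix[row][col] > target, every element to its right in the same row is larger too, so this row contributes nothing further: move up (row--). Otherwise matrix[row][col] <= target, and because the column is sorted, all row + 1 elements from matrix[0][col] down to matrix[row][col] are <= target: add row + 1 and move right (col++). Every step moves either up or right, so the walk ends after at most 2n - 1 steps.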
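The lessEqual helper in the Java code below uses the staircase walk. For comparison, a simpler counting strategy that relies only on each row being sorted is to binary search every row for the first value greater than the target. This is a sketch of that alternative, not the method used in the solution; RowBinarySearchCount and upperBound are hypothetical names. It costs O(n log n) per call instead of O(n).

```java
class RowBinarySearchCount {
    // Number of elements <= target in one ascending row, found as the
    // index of the first element that is > target (an "upper bound").
    static int upperBound(int[] row, int target) {
        int lo = 0, hi = row.length; // search the half-open range [lo, hi)
        while (lo < hi) {
            int mid = lo + (hi - lo) / 2;
            if (row[mid] <= target) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    // Alternative counting helper: sums the per-row counts.
    static int lessEqual(int[][] matrix, int target) {
        int count = 0;
        for (int[] row : matrix) {
            count += upperBound(row, target);
        }
        return count;
    }
}
```

Either helper can be plugged into the same binary search; the staircase version is preferred because it also exploits the column ordering.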


Java Code

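class Solution {
    public int kthSmallest(int[][] matrix, int k) {
        // The answer lies between the smallest and largest values.
        int beg = matrix[0][0];
        int end = matrix[matrix.length - 1][matrix.length - 1];

        while (beg < end) {
            // Avoids the overflow that (beg + end) / 2 can hit when both are large.
            int mid = beg + (end - beg) / 2;
            int count = lessEqual(matrix, mid);
            if (count < k) {
                // Fewer than k elements are <= mid, so the answer is larger.
                beg = mid + 1;
            } else {
                // At least k elements are <= mid, so the answer is mid or smaller.
                end = mid;
            }
        }

        // beg == end: the smallest value with at least k elements <= it.
        return beg;
    }

    // Counts elements <= target with a staircase walk from the bottom-left corner.
    private int lessEqual(int[][] matrix, int target) {
        int count = 0;
        int n = matrix.length;
        int row = n - 1;
        int col = 0;

        while (row >= 0 && col < n) {
            if (matrix[row][col] > target) {
                // The rest of this row is also > target; move up.
                row--;
            } else {
                // matrix[0..row][col] are all <= target; count them and move right.
                count += row + 1;
                col++;
            }
        }

        return count;
    }
}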
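For comparison, a commonly used alternative (a different technique from the value-space binary search above) treats the rows as n sorted lists and merges them with a min-heap, popping k - 1 times. This is a sketch; HeapKthSmallest is a hypothetical name. It runs in O(min(n, k) + k log min(n, k)) time with O(min(n, k)) extra space, so it is attractive when k is small but slower than binary search when k is close to n².

```java
import java.util.PriorityQueue;

class HeapKthSmallest {
    static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        // Each heap entry is {value, row, col}, ordered by value.
        PriorityQueue<int[]> heap = new PriorityQueue<>((a, b) -> Integer.compare(a[0], b[0]));
        // Seed with the first element of each row; rows at index >= k cannot
        // hold a smaller kth value because column 0 is sorted.
        for (int r = 0; r < Math.min(n, k); r++) {
            heap.offer(new int[] {matrix[r][0], r, 0});
        }
        // Pop the current minimum k - 1 times, pushing its right neighbour each time.
        for (int i = 0; i < k - 1; i++) {
            int[] top = heap.poll();
            int r = top[1];
            int c = top[2];
            if (c + 1 < n) {
                heap.offer(new int[] {matrix[r][c + 1], r, c + 1});
            }
        }
        // The heap minimum is now the kth smallest element.
        return heap.peek()[0];
    }
}
```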

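Time and Space Complexity

  • Time: O(n log(max - min)), where min = matrix[0][0] and max = matrix[n-1][n-1]. The binary search runs O(log(max - min)) iterations, and each call to lessEqual takes at most 2n - 1 steps.
  • Space: O(1) extra space; only a few integer variables are used.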
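The O(n) cost of the counting walk comes from the fact that each step moves either up or right, so it stops after at most n + n - 1 = 2n - 1 steps. As a quick check of that bound, this sketch instruments the walk and reports the largest step count over a range of targets; StaircaseSteps is a hypothetical name used only here.

```java
class StaircaseSteps {
    // Number of loop iterations the staircase walk performs for one target.
    static int steps(int[][] matrix, int target) {
        int n = matrix.length;
        int row = n - 1;
        int col = 0;
        int steps = 0;
        while (row >= 0 && col < n) {
            steps++;
            if (matrix[row][col] > target) {
                row--;
            } else {
                col++;
            }
        }
        return steps;
    }

    public static void main(String[] args) {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        int worst = 0;
        for (int t = 0; t <= 16; t++) {
            worst = Math.max(worst, steps(matrix, t));
        }
        System.out.println(worst); // prints 5, which equals 2n - 1 for n = 3
    }
}
```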


Dry Run

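Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Range: beg = 1, end = 15

  1. beg = 1, end = 15: mid = 8, count(<= 8) = 2 (1 and 5). 2 < 8, so beg = 9.
  2. beg = 9, end = 15: mid = 12, count(<= 12) = 6. 6 < 8, so beg = 13.
  3. beg = 13, end = 15: mid = 14, count(<= 14) = 8. 8 >= 8, so end = 14.
  4. beg = 13, end = 14: mid = 13, count(<= 13) = 8. 8 >= 8, so end = 13.

The loop stops with beg = end = 13, which matches the expected output.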
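To reproduce a dry run like this for other inputs, the binary search can be instrumented to record one line per iteration. This is a sketch; DryRunTrace and trace are hypothetical names, and lessEqual repeats the staircase count from the Java solution so the block compiles on its own.

```java
import java.util.ArrayList;
import java.util.List;

class DryRunTrace {
    // Same staircase count as the lessEqual helper in the Java solution.
    static int lessEqual(int[][] matrix, int target) {
        int n = matrix.length, row = n - 1, col = 0, count = 0;
        while (row >= 0 && col < n) {
            if (matrix[row][col] > target) {
                row--;
            } else {
                count += row + 1;
                col++;
            }
        }
        return count;
    }

    // Runs the value-space binary search and records each iteration.
    static List<String> trace(int[][] matrix, int k) {
        List<String> steps = new ArrayList<>();
        int n = matrix.length;
        int beg = matrix[0][0], end = matrix[n - 1][n - 1];
        while (beg < end) {
            int mid = beg + (end - beg) / 2;
            int count = lessEqual(matrix, mid);
            String move = count < k ? "beg=" + (mid + 1) : "end=" + mid;
            steps.add(String.format("beg=%d end=%d mid=%d count=%d -> %s", beg, end, mid, count, move));
            if (count < k) {
                beg = mid + 1;
            } else {
                end = mid;
            }
        }
        steps.add("answer=" + beg);
        return steps;
    }

    public static void main(String[] args) {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        trace(matrix, 8).forEach(System.out::println);
    }
}
```

For Example 1 this prints the four iterations listed above followed by answer=13.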


Conclusion

This is a classic application of binary search on values rather than indices. The key step is counting the elements <= mid in O(n) with a staircase walk, which relies on both the row and the column ordering.

The approach needs only O(1) extra space, compared with O(n²) for flattening and sorting, which matters for large matrices.