Kth Smallest Element in a Sorted Matrix
Problem Link
Kth-smallest-element-in-a-sorted-matrix
Problem Statement
You are given an n x n matrix where each row and each column is sorted in non-decreasing order. You need to find the kth smallest element in the matrix.
Note: The kth smallest element refers to the kth element in the sorted sequence, not necessarily a distinct one.
Examples
Example 1:
Input:
matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Output: 13
Explanation: The elements in order: [1, 5, 9, 10, 11, 12, 13, 13, 15]. The 8th smallest is 13.
Example 2:
Input:
matrix = [[-5]], k = 1
Output: -5
Constraints
- n == matrix.length == matrix[i].length
- 1 <= n <= 300
- -10^9 <= matrix[i][j] <= 10^9
- All rows and columns are sorted in non-decreasing order
- 1 <= k <= n^2
Intuition
We are given a row-wise and column-wise sorted matrix, and we are asked to find the kth smallest element.
Naively, we could flatten the matrix and sort it, but that would require O(n²) space and O(n² log n) time.
However, due to the sorted nature of the matrix, we can apply binary search on value space (not index space), from the minimum to maximum value in the matrix.
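At each midpoint value, we count how many elements are <= mid. This count never decreases as mid grows, which is what makes binary search valid.
- If the count is less than k, the kth smallest must be greater than mid, so search the right half.
- Else, the kth smallest is at most mid, so search the left half (including mid).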
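For comparison, the flatten-and-sort baseline mentioned above can be sketched as follows. This is only a reference point for checking results on small inputs, not the intended solution; the class name NaiveKthSmallest is a hypothetical name chosen for illustration.

```java
import java.util.Arrays;

// Hypothetical helper class, used only as a baseline for comparison.
class NaiveKthSmallest {
    // Flattens the matrix, sorts it, and reads position k - 1.
    // Costs O(n^2) extra space and O(n^2 log n) time.
    static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        int[] flat = new int[n * n];
        for (int r = 0; r < n; r++) {
            // Copy row r into its slot of the flat array.
            System.arraycopy(matrix[r], 0, flat, r * n, n);
        }
        Arrays.sort(flat);
        return flat[k - 1]; // k is 1-based
    }
}
```

On Example 1, NaiveKthSmallest.kthSmallest(matrix, 8) returns 13, which is the value the binary search should also reach.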
Approach
- Let low = matrix[0][0], high = matrix[n-1][n-1] (minimum and maximum values).
- While low < high:
  - Compute mid = low + (high - low) / 2.
  - Count the number of elements <= mid using a helper function.
  - If count is less than k, move low to mid + 1.
  - Else, move high to mid.
- When the loop ends, low is the kth smallest element. It is always a value present in the matrix, because the count can only increase at values that actually occur in the matrix.
Counting Strategy (Helper Function):
- Start from the bottom-left corner.
- For each column, if the value is <= mid, then all elements above are also <= mid. Add (row + 1) to count and move right.
- Else, move up.
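To make the walk easier to follow, here is a small sketch that records how much each column contributes instead of returning only the total. The class and method names (StaircaseTrace, perColumnCounts) are hypothetical and exist only for this illustration.

```java
// Hypothetical helper class for illustrating the counting walk.
class StaircaseTrace {
    // Returns, for each column, how many of its entries are <= target,
    // found by walking from the bottom-left corner up and to the right.
    static int[] perColumnCounts(int[][] matrix, int target) {
        int n = matrix.length;
        int[] counts = new int[n]; // columns never reached keep 0
        int row = n - 1;
        int col = 0;
        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                counts[col] = row + 1; // rows 0..row of this column qualify
                col++;                 // move right
            } else {
                row--;                 // move up
            }
        }
        return counts;
    }
}
```

For the Example 1 matrix and target 12, this returns [3, 2, 1]: three entries from the first column (1, 10, 12), two from the second (5, 11), and one from the third (9), for a total of 6.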
Java Code
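class Solution {
    public int kthSmallest(int[][] matrix, int k) {
        // Search over values: smallest and largest elements bound the answer.
        int beg = matrix[0][0];
        int end = matrix[matrix.length - 1][matrix.length - 1];
        while (beg < end) {
            // Midpoint rounds toward beg, so the interval always shrinks.
            int mid = beg + (end - beg) / 2;
            int count = lessEqual(matrix, mid);
            if (count < k) {
                beg = mid + 1; // fewer than k elements <= mid: answer is larger
            } else {
                end = mid;     // at least k elements <= mid: answer is <= mid
            }
        }
        return beg;
    }

    // Counts the elements <= target by walking from the bottom-left corner.
    private int lessEqual(int[][] matrix, int target) {
        int count = 0;
        int n = matrix.length;
        int row = n - 1;
        int col = 0;
        while (row >= 0 && col < n) {
            if (matrix[row][col] > target) {
                row--;            // too large: move up
            } else {
                count += row + 1; // rows 0..row of this column are <= target
                col++;            // move right
            }
        }
        return count;
    }
}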
Time and Space Complexity
- Time Complexity: O(n * log(max - min))
  Binary search over value range (log range) and each count is O(n).
- Space Complexity: O(1)
  No extra space used apart from variables.
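As a rough illustration of what this bound means under the stated constraints, the sketch below computes an upper limit on the number of binary-search rounds and on the cells visited. The names SearchCostBound, maxIterations, and maxCellVisits are hypothetical, and the figures are upper bounds rather than exact costs.

```java
// Hypothetical helper class for estimating the cost of the search.
class SearchCostBound {
    // Upper bound on binary-search iterations: the interval width (max - min)
    // is at least halved (rounding down) each round until it reaches 0.
    static int maxIterations(long min, long max) {
        int iterations = 0;
        for (long width = max - min; width > 0; width /= 2) {
            iterations++;
        }
        return iterations;
    }

    // Each staircase count makes at most 2n moves (n upward, n rightward).
    static long maxCellVisits(int n, long min, long max) {
        return 2L * n * maxIterations(min, max);
    }
}
```

With n = 300 and values spanning -10^9 to 10^9, maxIterations gives 31 and maxCellVisits gives 18,600, compared with the 90,000 elements a flatten-and-sort approach would have to copy and sort.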
Dry Run
Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Range: beg = 1, end = 15
- mid = 8 → count = 2 → beg = 9
- mid = 12 → count = 6 → beg = 13
- mid = 14 → count = 8 → end = 14
- mid = 13 → count = 8 → end = 13
Now beg == end == 13, so the loop ends.
Final value: 13 → 8th smallest.
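A dry run like this can also be produced mechanically. The following sketch repeats the same search but records one line per iteration; DryRunTrace, countLessEqual, and trace are hypothetical names introduced only for this illustration, and the output format is an arbitrary choice.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper class that logs each binary-search step.
class DryRunTrace {
    // Same staircase count as the solution's helper.
    static int countLessEqual(int[][] matrix, int target) {
        int n = matrix.length;
        int count = 0;
        int row = n - 1;
        int col = 0;
        while (row >= 0 && col < n) {
            if (matrix[row][col] > target) {
                row--;
            } else {
                count += row + 1;
                col++;
            }
        }
        return count;
    }

    // Runs the value-space binary search and records one line per iteration,
    // followed by a final line with the result.
    static List<String> trace(int[][] matrix, int k) {
        List<String> steps = new ArrayList<>();
        int n = matrix.length;
        int beg = matrix[0][0];
        int end = matrix[n - 1][n - 1];
        while (beg < end) {
            int mid = beg + (end - beg) / 2;
            int count = countLessEqual(matrix, mid);
            if (count < k) {
                beg = mid + 1;
                steps.add("mid=" + mid + " count=" + count + " -> beg=" + beg);
            } else {
                end = mid;
                steps.add("mid=" + mid + " count=" + count + " -> end=" + end);
            }
        }
        steps.add("result=" + beg);
        return steps;
    }
}
```

Calling DryRunTrace.trace(matrix, 8) on the input above and printing each entry gives a step-by-step log that can be compared against a hand-written dry run.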
Conclusion
This is a classic application of binary search on values rather than indices. The trick lies in efficiently counting elements <= mid in the matrix, leveraging its row and column sorted property.
This approach needs only O(1) extra memory and avoids sorting all n² elements, which matters for large matrices.