Kth smallest element in a row-wise and column-wise sorted 2D array
Given an n × n matrix mat[][] where each row and column is sorted in non-decreasing order, find the kth smallest element, where k lies in the range [1, n²].
Example:
Input: mat[][] = [[10, 20, 30, 40],
                  [15, 25, 35, 45],
                  [24, 29, 37, 48],
                  [32, 33, 39, 50]], k = 3
Output: 20
Explanation: The sorted order is [10, 15, 20, ...]; the 3rd element is 20.

Input: mat[][] = [[10, 20, 30, 40],
                  [15, 25, 35, 45],
                  [24, 29, 37, 48],
                  [32, 33, 39, 50]], k = 7
Output: 30
Explanation: The sorted order is [10, 15, 20, 24, 25, 29, 30, ...]; the 7th element is 30.
Table of Contents
[Naive Approach] Using Sorting - O(n² × log n²) Time and O(n²) Space
Copy all n² elements of mat[][] into a one-dimensional array and sort that array in non-decreasing order. The kth smallest element is then the element at index k − 1.
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
// Returns the kth smallest value (k is 1-based) of the n x n matrix
// by flattening every row into one vector and sorting it ascending.
int kthSmallest(vector<vector<int>>& mat, int k) {
    const int n = mat.size();

    // gather the first n entries of each of the n rows, row by row
    vector<int> flat;
    for (int r = 0; r < n; ++r) {
        flat.insert(flat.end(), mat[r].begin(), mat[r].begin() + n);
    }

    sort(flat.begin(), flat.end());

    // convert the 1-based rank to a 0-based index
    return flat[k - 1];
}
int main() {
    // 4 x 4 sample whose rows and columns are sorted ascending
    vector<vector<int>> mat{
        {10, 20, 30, 40},
        {15, 25, 35, 45},
        {24, 29, 37, 48},
        {32, 33, 39, 50},
    };
    const int k = 3;
    cout << kthSmallest(mat, k) << endl;
    return 0;
}
import java.util.ArrayList;
import java.util.Collections;
class GfG {

    /**
     * Returns the k-th smallest element (k is 1-based) of a 2D matrix by
     * copying every element into a flat primitive array and sorting it.
     *
     * <p>The sort does not depend on the matrix being sorted, and every
     * element of every row is considered, so ragged or rectangular rows are
     * handled as well as the n x n case.
     *
     * @param mat the matrix to search
     * @param k   1-based rank of the element to return; must lie in
     *            [1, total number of elements]
     * @return the k-th smallest element
     * @throws ArrayIndexOutOfBoundsException if k is outside that range
     */
    static int kthSmallest(int[][] mat, int k) {
        // count the elements first so the flat array is allocated exactly once
        int total = 0;
        for (int[] row : mat) {
            total += row.length;
        }

        // a primitive int[] avoids boxing every element into an Integer
        int[] arr = new int[total];
        int pos = 0;
        for (int[] row : mat) {
            System.arraycopy(row, 0, arr, pos, row.length);
            pos += row.length;
        }

        java.util.Arrays.sort(arr);

        // k is 1-based, array indices are 0-based
        return arr[k - 1];
    }

    public static void main(String[] args) {
        int[][] mat =
            {{10, 20, 30, 40},
             {15, 25, 35, 45},
             {24, 29, 37, 48},
             {32, 33, 39, 50}};
        int k = 3;
        int result = kthSmallest(mat, k);
        System.out.println(result);
    }
}
def kthSmallest(mat, k):
    """Return the k-th smallest value (1-based) of the n x n matrix ``mat``.

    Every entry ``mat[r][c]`` with ``0 <= r, c < n`` is collected, the values
    are sorted ascending, and the one at index ``k - 1`` is returned.
    """
    n = len(mat)
    ordered = sorted(mat[r][c] for r in range(n) for c in range(n))
    return ordered[k - 1]
if __name__ == "__main__":
    # 4 x 4 sample whose rows and columns are sorted ascending
    mat = [[10, 20, 30, 40],
           [15, 25, 35, 45],
           [24, 29, 37, 48],
           [32, 33, 39, 50]]
    k = 3
    print(kthSmallest(mat, k))
using System;
using System.Collections.Generic;
class GfG {
    // Returns the kth smallest value (k is 1-based) of the n x n matrix by
    // copying all entries into a list and sorting it ascending.
    static int kthSmallest(int[,] mat, int k) {
        int size = mat.GetLength(0);
        var values = new List<int>();
        for (int r = 0; r < size; r++)
            for (int c = 0; c < size; c++)
                values.Add(mat[r, c]);
        values.Sort();
        // 1-based rank -> 0-based index
        return values[k - 1];
    }

    public static void Main(string[] args) {
        // 4 x 4 sample whose rows and columns are sorted ascending
        int[,] mat = {
            {10, 20, 30, 40},
            {15, 25, 35, 45},
            {24, 29, 37, 48},
            {32, 33, 39, 50}
        };
        int k = 3;
        Console.WriteLine(kthSmallest(mat, k));
    }
}
// Returns the kth smallest value (k is 1-based) of the n x n matrix by
// copying every entry into one array and sorting it numerically.
function kthSmallest(mat, k) {
    const n = mat.length;
    const values = [];
    for (let r = 0; r < n; r++) {
        const row = mat[r];
        for (let c = 0; c < n; c++) {
            values.push(row[c]);
        }
    }
    // numeric comparator: the default sort would compare the values as strings
    values.sort((a, b) => a - b);
    return values[k - 1];
}
// Driver Code: 4 x 4 sample whose rows and columns are sorted ascending
let mat = [
    [10, 20, 30, 40],
    [15, 25, 35, 45],
    [24, 29, 37, 48],
    [32, 33, 39, 50]
];
let k = 3;
console.log(kthSmallest(mat, k));
Output
20
[Better Approach] Using Priority Queue - O(n² × log k) Time and O(k) Space
The idea is to use a max-heap that always holds the k smallest elements seen so far. Each element of mat[][] is pushed onto the heap; whenever the heap's size exceeds k, the top (largest) element is popped so that exactly k elements remain. After the whole matrix has been traversed, the top of the max-heap is the kth smallest element.
#include <iostream>
#include <vector>
#include <queue> // for priority_queue
using namespace std;
// Returns the kth smallest value (k is 1-based) of the n x n matrix.
// A max-heap capped at k entries keeps the k smallest values offered so
// far; once every entry has been offered, its top is the answer.
int kthSmallest(vector<vector<int>>& mat, int k) {
    const int n = mat.size();
    priority_queue<int> smallest; // std::priority_queue is a max-heap by default
    for (int r = 0; r < n; ++r) {
        for (int c = 0; c < n; ++c) {
            smallest.push(mat[r][c]);
            // drop the largest retained value once more than k are held
            if (smallest.size() > static_cast<size_t>(k)) {
                smallest.pop();
            }
        }
    }
    return smallest.top();
}
int main() {
    // 4 x 4 sample whose rows and columns are sorted ascending
    vector<vector<int>> mat{
        {10, 20, 30, 40},
        {15, 25, 35, 45},
        {24, 29, 37, 48},
        {32, 33, 39, 50},
    };
    const int k = 3;
    const int answer = kthSmallest(mat, k);
    cout << answer << endl;
    return 0;
}
import java.util.*;
class GfG {

    /**
     * Returns the k-th smallest value (k is 1-based) of the n x n matrix.
     *
     * <p>A max-heap is capped at k entries, so it always holds the k smallest
     * values offered so far; after every entry has been offered, its head is
     * the answer.
     *
     * @param mat n x n matrix to search
     * @param k   1-based rank of the value to return
     * @return the k-th smallest value
     */
    static int kthSmallest(int[][] mat, int k) {
        int n = mat.length;
        // reverse natural order: the head of the queue is its largest element
        PriorityQueue<Integer> kSmallest = new PriorityQueue<>(Collections.reverseOrder());
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                kSmallest.offer(mat[r][c]);
                // evict the largest once more than k values are held
                if (kSmallest.size() > k) {
                    kSmallest.poll();
                }
            }
        }
        return kSmallest.peek();
    }

    public static void main(String[] args) {
        // 4 x 4 sample whose rows and columns are sorted ascending
        int[][] mat = {
            {10, 20, 30, 40},
            {15, 25, 35, 45},
            {24, 29, 37, 48},
            {32, 33, 39, 50}
        };
        System.out.println(kthSmallest(mat, 3));
    }
}
import heapq
def kthSmallest(mat, k):
    """Return the k-th smallest value (1-based) of the n x n matrix ``mat``.

    ``heapq`` only provides a min-heap, so values are stored negated to make
    it behave as a max-heap.  The heap is capped at ``k`` entries, so after
    every entry has been pushed its top (negated back) is the answer.
    """
    size = len(mat)
    heap = []
    for r in range(size):
        for c in range(size):
            heapq.heappush(heap, -mat[r][c])
            # evict the largest retained value once more than k are held
            if len(heap) > k:
                heapq.heappop(heap)
    return -heap[0]
if __name__ == "__main__":
    # 4 x 4 sample whose rows and columns are sorted ascending
    mat = [[10, 20, 30, 40],
           [15, 25, 35, 45],
           [24, 29, 37, 48],
           [32, 33, 39, 50]]
    k = 3
    print(kthSmallest(mat, k))
using System;
using System.Collections.Generic;
class GfG {
    // Returns the kth smallest value (k is 1-based) of the n x n matrix
    // using a max-heap capped at k entries.
    static int kthSmallest(int[,] mat, int k) {
        int size = mat.GetLength(0);
        // reversed comparer makes the queue's root its largest element
        var heap = new PriorityQueue<int>(
            Comparer<int>.Create((a, b) => b.CompareTo(a)));
        for (int r = 0; r < size; r++) {
            for (int c = 0; c < size; c++) {
                heap.Enqueue(mat[r, c]);
                // evict the largest retained value once more than k are held
                if (heap.Count > k) heap.Dequeue();
            }
        }
        return heap.Peek();
    }

    public static void Main(string[] args) {
        // 4 x 4 sample whose rows and columns are sorted ascending
        int[,] mat = {
            {10, 20, 30, 40},
            {15, 25, 35, 45},
            {24, 29, 37, 48},
            {32, 33, 39, 50}
        };
        int k = 3;
        Console.WriteLine(kthSmallest(mat, k));
    }
}

// Array-backed binary heap ordered by the supplied comparer: the element the
// comparer ranks lowest sits at the root and is returned by Peek/Dequeue.
public class PriorityQueue<T> {
    private List<T> data;
    private IComparer<T> comparer;

    public PriorityQueue(IComparer<T> comparer) {
        this.data = new List<T>();
        this.comparer = comparer;
    }

    // Appends item, then sifts it up while it ranks strictly before its parent.
    public void Enqueue(T item) {
        data.Add(item);
        int child = data.Count - 1;
        while (child > 0) {
            int parent = (child - 1) / 2;
            if (comparer.Compare(data[child], data[parent]) >= 0) break;
            Swap(child, parent);
            child = parent;
        }
    }

    // Removes and returns the root. The last element is moved to the root and
    // sifted down toward its lower-ranked child (the left child wins ties).
    // Throws ArgumentOutOfRangeException when the queue is empty.
    public T Dequeue() {
        int last = data.Count - 1;
        T root = data[0];
        data[0] = data[last];
        data.RemoveAt(last);
        last--;
        int parent = 0;
        while (true) {
            int child = 2 * parent + 1;
            if (child > last) break;
            int right = child + 1;
            if (right <= last && comparer.Compare(data[right], data[child]) < 0)
                child = right;
            if (comparer.Compare(data[parent], data[child]) <= 0) break;
            Swap(parent, child);
            parent = child;
        }
        return root;
    }

    // Returns the root without removing it.
    public T Peek() => data[0];

    public int Count => data.Count;

    // Exchanges the elements stored at positions i and j.
    private void Swap(int i, int j) {
        T tmp = data[i];
        data[i] = data[j];
        data[j] = tmp;
    }
}
// Array-backed binary max-heap: the largest value sits at index 0.
class MaxHeap {
    constructor() {
        this.heap = [];
    }

    // Adds val and restores heap order by sifting it upward.
    push(val) {
        this.heap.push(val);
        this.bubbleUp(this.heap.length - 1);
    }

    // Removes and returns the largest value (undefined when the heap is empty).
    pop() {
        const top = this.heap[0];
        const last = this.heap.pop();
        if (this.heap.length !== 0) {
            this.heap[0] = last;
            this.bubbleDown(0);
        }
        return top;
    }

    // Returns the largest value without removing it.
    peek() {
        return this.heap[0];
    }

    size() {
        return this.heap.length;
    }

    // Moves the value at idx up while it is strictly larger than its parent.
    bubbleUp(idx) {
        const h = this.heap;
        const value = h[idx];
        let pos = idx;
        while (pos > 0) {
            const parentPos = Math.floor((pos - 1) / 2);
            if (value <= h[parentPos]) break;
            h[pos] = h[parentPos];
            h[parentPos] = value;
            pos = parentPos;
        }
    }

    // Moves the value at idx down, swapping it with its larger child while
    // that child is strictly larger (the left child wins ties).
    bubbleDown(idx) {
        const h = this.heap;
        const len = h.length;
        let pos = idx;
        while (true) {
            const left = 2 * pos + 1;
            const right = left + 1;
            let target = pos;
            if (left < len && h[left] > h[target]) target = left;
            if (right < len && h[right] > h[target]) target = right;
            if (target === pos) break;
            const tmp = h[pos];
            h[pos] = h[target];
            h[target] = tmp;
            pos = target;
        }
    }
}

// Returns the kth smallest value (k is 1-based) of the n x n matrix using a
// max-heap capped at k entries; after every entry has been pushed, the
// heap's top is the answer.
function kthSmallest(mat, k) {
    const n = mat.length;
    const heap = new MaxHeap();
    for (let r = 0; r < n; r++) {
        for (let c = 0; c < n; c++) {
            heap.push(mat[r][c]);
            // evict the largest retained value once more than k are held
            if (heap.size() > k) heap.pop();
        }
    }
    return heap.peek();
}
// 4 x 4 sample whose rows and columns are sorted ascending
const mat = [
    [10, 20, 30, 40],
    [15, 25, 35, 45],
    [24, 29, 37, 48],
    [32, 33, 39, 50]
];
const k = 3;
console.log(kthSmallest(mat, k));
Output
20
[Efficient Approach] Binary Search on Answer - O(n × log(max − min)) Time and O(1) Space
This approach binary-searches over candidate values rather than positions. The answer lies in the range from mat[0][0] to mat[n-1][n-1], so we binary search on that range and, in each iteration, count the elements that are less than or equal to the current middle value.
Step by Step Implementation:
- Initialize low to mat[0][0] (the minimum value of the matrix).
- Initialize high to mat[n-1][n-1] (the maximum value of the matrix).
- Initialize ans to 0.
- Perform binary search on the range [low, high]:
  => Compute the midpoint mid of the range.
  => If countSmallerEqual(mid), the number of elements less than or equal to mid, is less than k, set low to mid + 1.
  => Otherwise the count is at least k, so mid is a candidate answer: set ans to mid and narrow the search by setting high to mid - 1.
  => Because ans ends up as the smallest value with at least k elements less than or equal to it, it is always an element of the matrix.
- countSmallerEqual (a helper that counts the elements of the matrix that are less than or equal to a given value mid):
  => Start at the top-right corner: set row to 0, col to n-1, and count to 0.
  => If mat[row][col] is greater than mid, move left by decrementing col (every element below it in that column is also greater than mid).
  => If mat[row][col] is less than or equal to mid, then every element of mat[row][0..col] is less than or equal to mid, so add col + 1 to count and move down by incrementing row.
  => Stop when row reaches n or col becomes negative.
#include <iostream>
#include <vector>
using namespace std;
// Counts how many entries of the row- and column-sorted n x n matrix are
// less than or equal to x, in O(n) steps.
// Starts at the top-right corner: if the current entry is <= x, the whole
// prefix mat[row][0..col] is <= x (rows are sorted), so count it and move
// down; otherwise every entry below it in this column is also > x (columns
// are sorted), so move left.
int countSmallerEqual(vector<vector<int>>& mat, int x) {
    int n = mat.size();
    int count = 0;
    int row = 0;
    int col = n - 1;
    while (row < n && col >= 0) {
        if (mat[row][col] <= x) {
            count += (col + 1);
            row++;
        } else {
            col--;
        }
    }
    return count;
}

// Returns the kth smallest entry (k is 1-based, expected in [1, n*n]) of the
// row- and column-sorted n x n matrix. It binary searches the value range
// [mat[0][0], mat[n-1][n-1]] for the smallest value v that has at least k
// entries <= v.
//
// The bounds are kept in long long. With int bounds, high - low overflows
// (undefined behavior) once the matrix spans more than INT_MAX, and
// mid - 1 overflows when mid == INT_MIN.
int kthSmallest(vector<vector<int>>& mat, int k) {
    int n = mat.size();
    long long low = mat[0][0];
    long long high = mat[n - 1][n - 1];
    long long ans = 0;
    while (low <= high) {
        long long mid = low + (high - low) / 2;
        // low only grows and high only shrinks, so mid stays within
        // [mat[0][0], mat[n-1][n-1]] and the narrowing cast is lossless
        int count = countSmallerEqual(mat, static_cast<int>(mid));
        if (count < k) {
            // fewer than k entries are <= mid: the answer is larger
            low = mid + 1;
        } else {
            // mid is a candidate; keep looking for a smaller one
            ans = mid;
            high = mid - 1;
        }
    }
    return static_cast<int>(ans);
}
int main() {
    // 4 x 4 sample whose rows and columns are sorted ascending
    vector<vector<int>> mat{
        {10, 20, 30, 40},
        {15, 25, 35, 45},
        {24, 29, 37, 48},
        {32, 33, 39, 50},
    };
    const int k = 3;
    cout << kthSmallest(mat, k) << endl;
    return 0;
}
class GfG {

    /**
     * Counts how many entries of a row- and column-sorted n x n matrix are
     * less than or equal to {@code x}, in O(n) steps.
     *
     * <p>Starts at the top-right corner. If the current entry is {@code <= x},
     * the whole row prefix up to it is {@code <= x} (rows are sorted), so it
     * is counted and the walk moves down. Otherwise every entry below it in
     * that column is also {@code > x} (columns are sorted), so the walk
     * moves left.
     *
     * @param mat n x n matrix with sorted rows and columns
     * @param x   threshold value
     * @return number of entries {@code <= x}
     */
    static int countSmallerEqual(int[][] mat, int x) {
        int n = mat.length;
        int count = 0;
        int row = 0;
        int col = n - 1;
        while (row < n && col >= 0) {
            if (mat[row][col] <= x) {
                count += (col + 1);
                row++;
            } else {
                col--;
            }
        }
        return count;
    }

    /**
     * Returns the k-th smallest entry (k is 1-based) of a row- and
     * column-sorted n x n matrix.
     *
     * <p>It binary searches the value range
     * [{@code mat[0][0]}, {@code mat[n-1][n-1]}] for the smallest value that
     * has at least k entries {@code <=} it.
     *
     * <p>The bounds are held in {@code long}. With {@code int} bounds,
     * {@code high - low} overflows when the matrix spans more than
     * {@code Integer.MAX_VALUE}, and {@code mid - 1} wraps around to
     * {@code Integer.MAX_VALUE} when {@code mid == Integer.MIN_VALUE}. Either
     * case can make the loop never terminate.
     *
     * @param mat n x n matrix with sorted rows and columns
     * @param k   1-based rank, expected in [1, n * n]
     * @return the k-th smallest entry, or 0 if k exceeds n * n
     */
    static int kthSmallest(int[][] mat, int k) {
        int n = mat.length;
        long low = mat[0][0];
        long high = mat[n - 1][n - 1];
        long ans = 0;
        while (low <= high) {
            long mid = low + (high - low) / 2;
            // low only grows and high only shrinks, so mid stays within
            // [mat[0][0], mat[n-1][n-1]] and the cast to int is lossless
            int count = countSmallerEqual(mat, (int) mid);
            if (count < k) {
                // fewer than k entries are <= mid: the answer is larger
                low = mid + 1;
            } else {
                // mid is a candidate; keep looking for a smaller one
                ans = mid;
                high = mid - 1;
            }
        }
        return (int) ans;
    }

    public static void main(String[] args) {
        int[][] mat =
            {{10, 20, 30, 40},
             {15, 25, 35, 45},
             {24, 29, 37, 48},
             {32, 33, 39, 50}};
        int k = 3;
        int result = kthSmallest(mat, k);
        System.out.println(result);
    }
}
def countSmallerEqual(mat, x):
    """Count the entries of the sorted n x n matrix ``mat`` that are <= ``x``.

    The scan starts at the top-right corner.  When the current entry is
    <= x, the whole row prefix up to it is <= x (rows are sorted), so it is
    counted and the scan moves down.  Otherwise the scan moves left, because
    every entry below it in that column is also larger (columns are sorted).
    Runs in O(n) steps.
    """
    n = len(mat)
    total = 0
    r, c = 0, n - 1
    while r < n and c >= 0:
        if mat[r][c] <= x:
            total += c + 1
            r += 1
        else:
            c -= 1
    return total


def kthSmallest(mat, k):
    """Return the k-th smallest entry (1-based) of the sorted n x n matrix.

    Binary-searches the value range [mat[0][0], mat[n-1][n-1]] for the
    smallest value that has at least ``k`` entries <= it.  The result starts
    at 0 and is returned unchanged if no value in the range qualifies
    (e.g. k > n * n).
    """
    n = len(mat)
    lo = mat[0][0]
    hi = mat[n - 1][n - 1]
    best = 0
    while lo <= hi:
        mid = lo + (hi - lo) // 2
        if countSmallerEqual(mat, mid) >= k:
            # mid is a candidate; look for a smaller one
            best = mid
            hi = mid - 1
        else:
            # too few entries are <= mid, so the answer is larger
            lo = mid + 1
    return best
if __name__ == "__main__":
    # 4 x 4 sample whose rows and columns are sorted ascending
    mat = [[10, 20, 30, 40],
           [15, 25, 35, 45],
           [24, 29, 37, 48],
           [32, 33, 39, 50]]
    k = 3
    print(kthSmallest(mat, k))
using System;
class GfG {
    // Counts how many entries of the row- and column-sorted n x n matrix are
    // less than or equal to x, walking from the top-right corner in O(n) steps.
    static int countSmallerEqual(int[,] mat, int x) {
        int n = mat.GetLength(0);
        int count = 0;
        int row = 0;
        int col = n - 1;
        while (row < n && col >= 0) {
            if (mat[row, col] <= x) {
                // the row is sorted, so mat[row, 0..col] are all <= x
                count += (col + 1);
                row++;
            } else {
                // the column is sorted, so every entry below is also > x
                col--;
            }
        }
        return count;
    }

    // Returns the kth smallest entry (k is 1-based, expected in [1, n*n]) of
    // the row- and column-sorted n x n matrix. It binary searches the value
    // range [mat[0,0], mat[n-1,n-1]] for the smallest value with at least k
    // entries <= it, and returns 0 if k exceeds n*n.
    //
    // The bounds are kept in long. C# int arithmetic is unchecked by default,
    // so with int bounds high - low wraps when the matrix spans more than
    // int.MaxValue, and mid - 1 wraps to int.MaxValue when mid == int.MinValue.
    // Either case can make the loop never terminate.
    static int kthSmallest(int[,] mat, int k) {
        int n = mat.GetLength(0);
        long low = mat[0, 0];
        long high = mat[n - 1, n - 1];
        long ans = 0;
        while (low <= high) {
            long mid = low + (high - low) / 2;
            // low only grows and high only shrinks, so mid stays within
            // [mat[0,0], mat[n-1,n-1]] and the cast to int is lossless
            int count = countSmallerEqual(mat, (int)mid);
            if (count < k) {
                // fewer than k entries are <= mid: the answer is larger
                low = mid + 1;
            } else {
                // mid is a candidate; keep looking for a smaller one
                ans = mid;
                high = mid - 1;
            }
        }
        return (int)ans;
    }

    public static void Main(string[] args) {
        int[,] mat =
            {{10, 20, 30, 40},
             {15, 25, 35, 45},
             {24, 29, 37, 48},
             {32, 33, 39, 50}};
        int k = 3;
        int result = kthSmallest(mat, k);
        Console.WriteLine(result);
    }
}
// Counts the entries of the sorted n x n matrix that are <= x by walking a
// staircase from the top-right corner toward the bottom-left (O(n) steps).
function countSmallerEqual(mat, x) {
    const n = mat.length;
    let total = 0;
    let r = 0;
    let c = n - 1;
    while (r < n && c >= 0) {
        if (mat[r][c] <= x) {
            // mat[r][0..c] are all <= x because the row is sorted
            total += c + 1;
            r += 1;
        } else {
            // everything below mat[r][c] in this column is also > x
            c -= 1;
        }
    }
    return total;
}

// Returns the kth smallest entry (k is 1-based) by binary searching the
// value range [mat[0][0], mat[n-1][n-1]] for the smallest value that has
// at least k entries <= it.
function kthSmallest(mat, k) {
    const n = mat.length;
    let lo = mat[0][0];
    let hi = mat[n - 1][n - 1];
    let best = 0;
    while (lo <= hi) {
        const mid = lo + Math.floor((hi - lo) / 2);
        if (countSmallerEqual(mat, mid) < k) {
            // too few entries are <= mid, so the answer is larger
            lo = mid + 1;
        } else {
            // mid is a candidate; look for a smaller one
            best = mid;
            hi = mid - 1;
        }
    }
    return best;
}
// Driver Code: 4 x 4 sample whose rows and columns are sorted ascending
const mat = [
    [10, 20, 30, 40],
    [15, 25, 35, 45],
    [24, 29, 37, 48],
    [32, 33, 39, 50]
];
const k = 3;
console.log(kthSmallest(mat, k));
Output
20
Time Complexity: O(n × log(max(mat) − min(mat))). The binary search over the matrix's value range takes O(log(max(mat) − min(mat))) steps. Each step calls the counting function, which runs in O(n) time by walking from the top-right corner toward the bottom-left.
Auxiliary Space: O(1), since the algorithm uses only a fixed number of variables and no extra data structures.