LeetCode 378. Kth Smallest Element in a Sorted Matrix

Post by ailswan May. 01, 2024

Problem Statement

link: LeetCode.cn LeetCode

Given an n x n matrix where each of the rows and columns is sorted in ascending order, return the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

You must find a solution with a memory complexity better than O(n^2).

Example:

Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Output: 13

Input: matrix = [[-5]], k = 1
Output: -5

Constraints:

n == matrix.length == matrix[i].length

1 <= n <= 300

-10^9 <= matrix[i][j] <= 10^9

All the rows and columns of matrix are guaranteed to be sorted in non-decreasing order.

1 <= k <= n^2

Solution Approach

The solution uses binary search on the possible values of the matrix elements, combined with a count function to determine the number of elements less than or equal to a midpoint, to efficiently find the k-th smallest element in the sorted matrix.

Algorithm

  1. Binary Search on Values: The algorithm performs binary search on the possible values in the matrix to narrow down the range where the k-th smallest element could be, using the smallest and largest values in the matrix as the search bounds.
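  2. Counting Elements: For each midpoint value during the binary search, a helper function counts the number of elements in the matrix that are less than or equal to the midpoint. It starts at the bottom-left corner and walks a staircase: if the current element is less than or equal to the midpoint, every element above it in that column also qualifies, so the whole column prefix is counted and the walk moves right; otherwise it moves up one row.
  3. Adjusting Search Range: If at least k elements are less than or equal to the midpoint, the answer is at most the midpoint, so the upper bound moves down to it; otherwise the lower bound moves to midpoint + 1. The search ends when the bounds meet at the smallest value with at least k elements less than or equal to it. The count only changes at values that occur in the matrix, so this value is itself a matrix element and is the k-th smallest.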
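The counting step can be sketched on its own to see what the binary search relies on. The function name `count_less_equal` is chosen here for illustration and does not appear in the original post; the matrix is the first example from the problem statement.

```python
from typing import List


def count_less_equal(matrix: List[List[int]], mid: int) -> int:
    """Count elements <= mid in an n x n matrix with sorted rows and columns."""
    n = len(matrix)
    i, j = n - 1, 0  # start at the bottom-left corner
    count = 0
    while i >= 0 and j < n:
        if matrix[i][j] <= mid:
            # Column j is sorted, so rows 0..i all hold values <= mid.
            count += i + 1
            j += 1
        else:
            # Too large: move up to a smaller value in the same column.
            i -= 1
    return count


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(count_less_equal(matrix, 11))  # 5
print(count_less_equal(matrix, 12))  # 6
print(count_less_equal(matrix, 13))  # 8
```

With k = 8, the count first reaches 8 at the value 13, which matches the expected output of the first example.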

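Complexity

time complexity: O(n * log(max - min)), where max and min are the largest and smallest values in the matrix. Each count walks at most 2n cells, and the binary search halves the value range on every step.

space complexity: O(1)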
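To get a feel for the log(max - min) factor, the sketch below counts the steps of the same style of binary search over the widest value range the constraints allow. The helper `search_steps` and its `target` parameter are hypothetical stand-ins for the real predicate; they are not part of the original solution.

```python
import math


def search_steps(lo: int, hi: int, target: int):
    """Find target in [lo, hi] by bisection; return (result, number of steps)."""
    steps = 0
    while lo < hi:
        mid = (lo + hi) // 2
        if mid >= target:  # stand-in for "count of elements <= mid is at least k"
            hi = mid
        else:
            lo = mid + 1
        steps += 1
    return lo, steps


LO, HI = -10**9, 10**9  # value bounds from the problem constraints
bound = math.ceil(math.log2(HI - LO + 1))
print(bound)  # 31

for target in (LO, -1, 0, 12345, HI):
    result, steps = search_steps(LO, HI, target)
    assert result == target and steps <= bound
```

So under these constraints the count runs about 31 times at most, each time in O(n).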

Implementation

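    from typing import List


    class Solution:
        def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
            n = len(matrix)

            def check(mid):
                # Count elements <= mid, walking from the bottom-left corner.
                i, j = n - 1, 0
                num = 0
                while i >= 0 and j < n:
                    if matrix[i][j] <= mid:
                        # Column j is sorted, so rows 0..i all hold values <= mid.
                        num += i + 1
                        j += 1
                    else:
                        i -= 1
                return num >= k

            # Find the smallest value with at least k elements <= it.
            l, r = matrix[0][0], matrix[n - 1][n - 1]
            while l < r:
                mid = (l + r) // 2
                if check(mid):
                    r = mid
                else:
                    l = mid + 1
            return l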
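As a quick sanity check, the same approach can be written as a standalone function and compared against a brute-force answer. This is a minimal sketch: the names `kth_smallest` and `kth_smallest_brute` are made up for this example, and the brute-force version uses O(n^2) extra memory, so it serves only as a test oracle and not as an accepted solution.

```python
from typing import List


def kth_smallest(matrix: List[List[int]], k: int) -> int:
    """Binary search on values with a staircase count, as described above."""
    n = len(matrix)

    def count_le(mid: int) -> int:
        i, j, count = n - 1, 0, 0
        while i >= 0 and j < n:
            if matrix[i][j] <= mid:
                count += i + 1
                j += 1
            else:
                i -= 1
        return count

    lo, hi = matrix[0][0], matrix[-1][-1]
    while lo < hi:
        mid = (lo + hi) // 2
        if count_le(mid) >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo


def kth_smallest_brute(matrix: List[List[int]], k: int) -> int:
    """Reference answer: flatten, sort, and index. Uses O(n^2) memory."""
    return sorted(v for row in matrix for v in row)[k - 1]


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest(matrix, 8))   # 13
print(kth_smallest([[-5]], 1))   # -5

# Compare both versions for every valid k on the example matrix.
for k in range(1, len(matrix) ** 2 + 1):
    assert kth_smallest(matrix, k) == kth_smallest_brute(matrix, k)
```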