DEV Community

Georgii Kliukovkin

Kth Largest Element in an Array - Quickselect Using Lomuto Partitioning Scheme.

According to Leetcode:

We highly recommend Kth Largest Element in an Array, which has been asked many times in an Amazon phone interview.

Task

Given an integer array nums and an integer k, return the kth largest element in the array.
Note that it is the kth largest element in the sorted order, not the kth distinct element.
You must solve it in O(n) time complexity.

Heap

There are several ways to solve this task. The most straightforward is to iterate through the array and keep the k largest elements seen so far. A min-heap is a natural fit for that: its smallest element, which is the current candidate for the answer, is always at the top.
The algorithm:

  1. iterate through the array and add each element to the heap
  2. if the heap size is greater than k, remove the smallest element (which is at the head of the priority queue)
  3. after the loop, return the head of the heap: the smallest of the k largest elements is the kth largest
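```java
import java.util.PriorityQueue;

class Solution {
    public int findKthLargest(int[] nums, int k) {
        // min-heap holding the k largest elements seen so far
        PriorityQueue<Integer> heap = new PriorityQueue<>();
        for (int n : nums) {
            heap.add(n);
            // once the heap holds more than k elements, evict the smallest
            if (heap.size() > k) {
                heap.poll();
            }
        }
        // the smallest of the k largest elements is the kth largest
        return heap.peek();
    }
}
```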

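We iterate through the array and insert each element into a heap that never holds more than k + 1 elements, so every insertion and removal costs O(log k). The total time complexity is therefore O(N log k), with O(k) extra space. Can we do better? :)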

Quickselect time complexity

First, let’s discuss the high-level design, and then we can dive deeper. Here is a link to the problem.
We need to find the k-th largest element in the array. Since it is more comfortable for us to think about an array sorted in non-decreasing order, let’s rephrase the task: we need to find the element that would sit at index N - k (0-based) if the array were sorted, i.e. the (N - k + 1)-th smallest element.
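To make the index mapping concrete, here is a minimal sketch of the naive baseline this rephrasing suggests: sort a copy and read the element at index N - k. It runs in O(N log N), so it does not meet the O(n) requirement, but it is a handy reference when testing a faster solution. The class and method names (SortBaseline, kthLargestBySorting) are hypothetical.

```java
import java.util.Arrays;

public class SortBaseline {
    // Baseline: sort a copy and read index n - k.
    // O(n log n) time, O(n) extra space; the input array is left untouched.
    static int kthLargestBySorting(int[] nums, int k) {
        int[] sorted = nums.clone();
        Arrays.sort(sorted);              // non-decreasing order
        return sorted[sorted.length - k]; // kth largest == 0-based index n - k
    }

    public static void main(String[] args) {
        int[] nums = {3, 2, 1, 5, 6, 4};
        // sorted: [1, 2, 3, 4, 5, 6]; n - k = 6 - 2 = 4
        System.out.println(kthLargestBySorting(nums, 2)); // prints 5
    }
}
```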

The idea behind the algorithm is to reuse the partitioning step of quicksort.
We partition the whole array (O(n) time).
Then we throw away the part that cannot contain the answer and continue with the other part, which with a reasonable pivot is about half of the array (O(n/2) time).

Again we throw away about half of what is left and continue with roughly 1/4 of the original array (O(n/4) time).
We keep doing so until we reach a single element. To summarize the overall time complexity:

$$n + \frac{n}{2} + \frac{n}{4} + \frac{n}{8} + \dots < 2n = O(n)$$

You may be confused about the time complexity, thinking “wait, this is almost the same as binary search, which is O(log N); where is the log N part?”. You are right that there are about log N rounds of partitioning. The key point is that the rounds do not all do the same amount of work: each one scans only the part of the array that is still left.

This more precise analysis, which uses the fact that the work done keeps decreasing on each iteration, gives an O(n) runtime on average (the expected runtime when the pivot is chosen at random).
If you are still confused about the time complexity, take a look at this answer and this article.

Partition algorithm (Lomuto partitioning scheme)

We choose a pivot and find its final position in the sorted array in linear time using the so-called partition algorithm.
The toughest thing in this algorithm is understanding how the partition works. Let’s imagine we have an array [2,6,3,4,7,5]. With the Lomuto partitioning scheme the algorithm looks like this:

  1. choose the last element as the pivot
  2. create 2 pointers, i and storeIndex, both starting at the beginning of the target interval
  3. the pointer i scans the whole interval and checks a condition: if the value at i is less than the pivot, swap it with the value at storeIndex and increment storeIndex
  4. swap the pivot (which is placed at the end of the interval) with the value at storeIndex (since this is the place for our pivot value) and return storeIndex. All of this takes a single for loop plus the dedicated variable storeIndex, which starts at the index of the first element of the interval (in our case it is 0).

[Figure: Lomuto partition scheme]

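Now we end up with storeIndex = 3, which is exactly the place for our pivot element. We swap the pivot into it and return this pivot index. The array now looks like [2,3,4,5,7,6] with pivot index 3: every element to the left of index 3 is smaller than the pivot 5, and every element to the right is greater than or equal to it.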

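```java
// Assumes nums is an int[] field of the enclosing class.
public void swap(int i, int j) {
  int temp = nums[i];
  nums[i] = nums[j];
  nums[j] = temp;
}

// Lomuto partition of nums[start..end] using the last element as the pivot.
// Returns the index where the pivot ends up.
public int partition(int start, int end) {
  int pivot = nums[end];
  int storeIndex = start;
  // invariant: nums[start..storeIndex-1] holds the elements already known to be < pivot
  for (int i = start; i < end; i++) {
    if (nums[i] < pivot) {
      swap(i, storeIndex);
      storeIndex++;
    }
  }
  // don't forget to move the pivot from the end of the interval to its position
  swap(end, storeIndex);
  return storeIndex;
}
```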
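The walkthrough on [2,6,3,4,7,5] can be replayed with a small self-contained sketch. It is the same Lomuto partition, written as a hypothetical LomutoDemo class that takes the array as a parameter (instead of using a field) and prints the array after every swap inside the loop:

```java
import java.util.Arrays;

public class LomutoDemo {
    static void swap(int[] a, int i, int j) {
        int temp = a[i];
        a[i] = a[j];
        a[j] = temp;
    }

    // Lomuto partition of a[start..end] around the pivot a[end];
    // prints the array after every swap inside the loop.
    static int partition(int[] a, int start, int end) {
        int pivot = a[end];
        int storeIndex = start;
        for (int i = start; i < end; i++) {
            if (a[i] < pivot) {
                swap(a, i, storeIndex);
                storeIndex++;
                System.out.println("i=" + i + ": " + Arrays.toString(a) + ", storeIndex=" + storeIndex);
            }
        }
        swap(a, end, storeIndex); // move the pivot to its final position
        System.out.println("final: " + Arrays.toString(a) + ", pivot index=" + storeIndex);
        return storeIndex;
    }

    public static void main(String[] args) {
        partition(new int[]{2, 6, 3, 4, 7, 5}, 0, 5);
        // Output:
        // i=0: [2, 6, 3, 4, 7, 5], storeIndex=1
        // i=2: [2, 3, 6, 4, 7, 5], storeIndex=2
        // i=3: [2, 3, 4, 6, 7, 5], storeIndex=3
        // final: [2, 3, 4, 5, 7, 6], pivot index=3
    }
}
```

At i=0 the element 2 is swapped with itself, and at i=1 and i=4 (values 6 and 7) nothing happens because they are not smaller than the pivot.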

Random pivot

The worst case of this algorithm is O(n^2). Why is that so? The algorithm is sensitive to the pivot that is chosen. Imagine the array is already sorted and each time you select the first (or, as in the Lomuto scheme above, the last) element as the pivot. Then each partition shrinks the range of elements by only 1, so the total work is (n - 1) + (n - 2) + ... + 1, which is O(n^2). To avoid this, we select a random pivot each time, which makes such a worst case extremely unlikely:

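```java
Random random = new Random(); // requires: import java.util.Random;
// assume 'start' is the index of the first element of the search interval
// and 'end' is the index of the last element of that interval, then:
// nextInt(bound) returns a value in [0, bound), so "+ 1" lets 'end' be picked too
int pivotIndex = start + random.nextInt(end - start + 1);
```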
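To see the quadratic worst case concretely, here is a minimal sketch that always pivots on the last element, as the Lomuto partition above does, and searches for the minimum of an already sorted array. The WorstCaseDemo class and its comparison counter are hypothetical, added only for illustration. Every partition places the pivot at the very end of the interval, so the interval shrinks by one element per round and the comparisons add up to n(n - 1)/2.

```java
public class WorstCaseDemo {
    static long comparisons = 0;

    static void swap(int[] a, int i, int j) {
        int t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    // Lomuto partition with the LAST element as the pivot (no randomization).
    static int partition(int[] a, int start, int end) {
        int pivot = a[end];
        int storeIndex = start;
        for (int i = start; i < end; i++) {
            comparisons++; // one comparison against the pivot
            if (a[i] < pivot) {
                swap(a, i, storeIndex);
                storeIndex++;
            }
        }
        swap(a, end, storeIndex);
        return storeIndex;
    }

    // Quickselect for the 0-based sorted index 'target', always pivoting on the last element.
    static int select(int[] a, int target) {
        int start = 0, end = a.length - 1;
        while (start < end) {
            int p = partition(a, start, end);
            if (p == target) return a[p];
            if (p < target) start = p + 1; else end = p - 1;
        }
        return a[start];
    }

    public static void main(String[] args) {
        int n = 1000;
        int[] sorted = new int[n];
        for (int i = 0; i < n; i++) sorted[i] = i; // already sorted input
        int smallest = select(sorted, 0);          // k = n, i.e. the minimum
        System.out.println(smallest + " " + comparisons);
        // prints: 0 499500  (= n(n - 1)/2 comparisons)
    }
}
```

Doubling n roughly quadruples the comparison count here, while the randomized pivot keeps the expected count linear.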

Quickselect algorithm

Finally, we need to implement the quickselect algorithm. The steps are:

  • select a random pivot
  • partition the array around this pivot and get its new index
  • if the index is equal to N - k, we have found our value; otherwise continue with the part of the array that contains index N - k and repeat the algorithm

Here is the source code:
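```java
import java.util.Random;

class Solution {
    private int[] nums;
    // one shared generator instead of allocating a new Random on every call
    private final Random random = new Random();

    private void swap(int i, int j) {
        int temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
    }

    // Lomuto partition of nums[start..end] around nums[pivotIndex].
    // Returns the final (sorted) position of the pivot.
    private int partition(int start, int end, int pivotIndex) {
        int pivot = nums[pivotIndex];
        // move the pivot out of the way, to the end of the interval
        swap(end, pivotIndex);
        int storeIndex = start;
        for (int i = start; i < end; i++) {
            if (nums[i] < pivot) {
                swap(i, storeIndex);
                storeIndex++;
            }
        }
        // don't forget to move the pivot from the end of the interval to its position
        swap(end, storeIndex);
        return storeIndex;
    }

    public int findKthLargest(int[] nums, int k) {
        this.nums = nums;
        // the kth largest element sits at index N - k in sorted order
        return quickselect(0, nums.length - 1, nums.length - k);
    }

    // Returns the element that would be at index 'target' if nums[start..end] were sorted.
    private int quickselect(int start, int end, int target) {
        if (start == end) {
            return nums[start];
        }
        // random pivot in [start, end], both ends inclusive
        int pivotIndex = start + random.nextInt(end - start + 1);
        pivotIndex = partition(start, end, pivotIndex);
        if (pivotIndex == target) {
            return nums[pivotIndex];
        }
        if (pivotIndex < target) {
            return quickselect(pivotIndex + 1, end, target); // answer is in the right part
        }
        return quickselect(start, pivotIndex - 1, target);   // answer is in the left part
    }
}
```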
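Because each recursive call in quickselect is the last thing the method does, the recursion can be replaced with a loop that only moves start and end. Below is a sketch of that variant, assuming (as in the version above) that reordering the input array in place is acceptable. The class name IterativeQuickselect and the shared RANDOM field are hypothetical choices for illustration, not part of the original solution. The main method runs the two examples from the LeetCode problem statement.

```java
import java.util.Random;

public class IterativeQuickselect {
    private static final Random RANDOM = new Random();

    private static void swap(int[] a, int i, int j) {
        int t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    // Lomuto partition of a[start..end] around a[pivotIndex]; returns the pivot's final index.
    private static int partition(int[] a, int start, int end, int pivotIndex) {
        int pivot = a[pivotIndex];
        swap(a, pivotIndex, end); // park the pivot at the end
        int storeIndex = start;
        for (int i = start; i < end; i++) {
            if (a[i] < pivot) {
                swap(a, i, storeIndex);
                storeIndex++;
            }
        }
        swap(a, end, storeIndex); // move the pivot to its sorted position
        return storeIndex;
    }

    // Returns the kth largest element; reorders nums in place.
    public static int findKthLargest(int[] nums, int k) {
        int target = nums.length - k; // 0-based index in sorted order
        int start = 0, end = nums.length - 1;
        while (start < end) {
            int pivotIndex = start + RANDOM.nextInt(end - start + 1);
            pivotIndex = partition(nums, start, end, pivotIndex);
            if (pivotIndex == target) return nums[pivotIndex];
            if (pivotIndex < target) start = pivotIndex + 1; // answer is to the right
            else end = pivotIndex - 1;                       // answer is to the left
        }
        return nums[start];
    }

    public static void main(String[] args) {
        System.out.println(findKthLargest(new int[]{3, 2, 1, 5, 6, 4}, 2));          // prints 5
        System.out.println(findKthLargest(new int[]{3, 2, 3, 1, 2, 4, 5, 5, 6}, 4)); // prints 4
    }
}
```

The loop keeps the invariant that index target lies inside [start, end], so when the interval shrinks to a single element, that element is the answer. This removes the recursion depth concern without changing the expected O(n) running time.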
