Tags: java, quicksort, quickselect

Correct conditions for quickselect


I am implementing the quickselect algorithm to find the kth element of an array, and I am stuck on a problem I don't know how to resolve. Here is my code, which doesn't work:

public static void main (String[] args) {
    int[] arr = new int[]{7,6,5,4,3,2,1};
    int k = 4;
    quickSelect(arr, 0, arr.length - 1, k);
    System.out.println(arr[k]);
}

private static void quickSelect(int[] nums, int start, int end, int k) {
    if (start < end) {
        int partitionIndex = getPartitionIndex(nums, start, end);
        if (partitionIndex == k) {
            return;
        }
        quickSelect(nums, start, partitionIndex - 1, k);
        quickSelect(nums, partitionIndex + 1, end, k);
    }
}

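// Lomuto partition: uses the last element as the pivot and returns its final index.
private static int getPartitionIndex(int[] nums, int start, int end) {
    int pivot = nums[end];
    int index = start;
    // The pivot itself sits at nums[end], so the scan stops just before it.
    for (int i = start; i < end; i++) {
        int current = nums[i];
        if (current < pivot) {
            swap(nums, index, i);
            index++;
        }
    }
    // Move the pivot between the smaller and the larger-or-equal elements.
    swap(nums, index, end);
    return index;
}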

private static void swap(int[] nums, int i, int j) {
    // The XOR swap below would zero the element if i == j, so skip that case.
    if (i == j) {
        return;
    }
    nums[i] = nums[i] ^ nums[j];
    nums[j] = nums[i] ^ nums[j];
    nums[i] = nums[i] ^ nums[j];
}

Sure, if I remove these lines:

        if (partitionIndex == k) {
            return;
        }

It becomes quicksort and works fine. I understand why my version fails: because I return as soon as that condition holds, the part of the array from 0 to k may not be sorted. But I can't find the right conditions for sorting only the first k elements and leaving the rest alone, so that I don't do any extra work. I've looked at some implementations online and spent some time on the code above without figuring it out, so I'm reaching out for help.


Solution

  • If k < partitionIndex, recurse only into the left partition; otherwise, recurse only into the right partition. This replaces the two unconditional recursive calls:

            if (k < partitionIndex)
                quickSelect(nums, start, partitionIndex - 1, k);
            else
                quickSelect(nums, partitionIndex + 1, end, k);
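
For reference, here is a minimal sketch of how the whole thing might look once that condition is in place. The class name `QuickSelectSketch` and the wrapper method `select` are illustrative assumptions, not part of the original code, and the sketch treats k as a 0-based index into the sorted order. Because only one side is visited per call, the element at index k ends up in its sorted position while the rest of the array stays only partially ordered.

```java
import java.util.Arrays;

class QuickSelectSketch {

    // Hypothetical wrapper: returns the element that would sit at index k
    // (0-based) if nums were sorted. Works on a copy of the input array.
    static int select(int[] nums, int k) {
        if (k < 0 || k >= nums.length) {
            throw new IllegalArgumentException("k out of range: " + k);
        }
        int[] copy = Arrays.copyOf(nums, nums.length);
        quickSelect(copy, 0, copy.length - 1, k);
        return copy[k];
    }

    private static void quickSelect(int[] nums, int start, int end, int k) {
        if (start >= end) {
            return; // zero or one element: nothing left to partition
        }
        int partitionIndex = getPartitionIndex(nums, start, end);
        if (partitionIndex == k) {
            return; // the pivot landed exactly at index k
        }
        // Recurse only into the side that contains index k.
        if (k < partitionIndex) {
            quickSelect(nums, start, partitionIndex - 1, k);
        } else {
            quickSelect(nums, partitionIndex + 1, end, k);
        }
    }

    // Lomuto partition with the last element as the pivot.
    private static int getPartitionIndex(int[] nums, int start, int end) {
        int pivot = nums[end];
        int index = start;
        for (int i = start; i < end; i++) {
            if (nums[i] < pivot) {
                swap(nums, index, i);
                index++;
            }
        }
        swap(nums, index, end);
        return index;
    }

    // Plain temp-variable swap; safe even when i == j.
    private static void swap(int[] nums, int i, int j) {
        int tmp = nums[i];
        nums[i] = nums[j];
        nums[j] = tmp;
    }

    public static void main(String[] args) {
        int[] arr = {7, 6, 5, 4, 3, 2, 1};
        System.out.println(select(arr, 4)); // prints 5
    }
}
```

With a last-element pivot, an already sorted or reverse-sorted input like the one above still hits the quadratic worst case; picking the pivot at random is a common way to make that unlikely.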