
What is nth_element, what exactly does it do, and how is it implemented?


I understood most STL algorithms until I reached std::nth_element. I'm stuck on it: I don't know how it works or what exactly it does.

For the sake of education and understanding, can someone explain how std::nth_element works?

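#include <algorithm>
#include <iostream>
#include <vector>

int main() {
    std::vector<int> v{ 9, 3, 6, 2, 1, 7, 8, 5, 4, 0 };
    // Ask for the element that belongs at index 2 in sorted order
    std::nth_element(v.begin(), v.begin() + 2, v.end());
    for (auto i : v)
        std::cout << i << " ";
    std::cout << '\n';
}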

The output:

1 0 2 3 6 7 8 5 4 9 
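To see which parts of this output are actually guaranteed, it can help to check it mechanically against a fully sorted copy of the data. The sketch below assumes a hypothetical helper, satisfies_nth_element (not a standard function), that returns true when index n holds the element a full sort would put there and no element before n is greater than any element from n onward:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of the standard library): returns true if
// `v` satisfies the nth_element guarantees for index `n`, using operator<.
bool satisfies_nth_element(const std::vector<int>& v, std::size_t n) {
    if (n >= v.size()) return true;  // nth == end(): nothing is required

    // 1. v[n] must be the element a full sort would place at index n.
    //    v is a permutation of the input, so sorting a copy of v gives
    //    the same sorted order as sorting the original input.
    std::vector<int> sorted = v;
    std::sort(sorted.begin(), sorted.end());
    if (v[n] != sorted[n]) return false;

    // 2. !(v[j] < v[i]) must hold for every i < n and every j >= n.
    //    It suffices to compare the largest element of the front part
    //    with the smallest element of the back part.
    if (n == 0) return true;
    auto split = v.begin() + static_cast<std::ptrdiff_t>(n);
    int max_front = *std::max_element(v.begin(), split);
    int min_back = *std::min_element(split, v.end());
    return !(min_back < max_front);
}
```

Any arrangement that passes this check is a valid result. The particular order 1 0 2 3 6 7 8 5 4 9 is simply what your standard library produced; another implementation may print a different, equally valid order.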

Here is some explanation from cppreference.com:

nth_element is a partial sorting algorithm that rearranges elements in [first, last) such that:

  • The element pointed at by nth is changed to whatever element would occur in that position if [first, last) was sorted.
  • All of the elements before this new nth element are less than or equal to the elements after the new nth element. More formally, nth_element partially sorts the range [first, last) in ascending order so that the condition !(*j < *i) (for the first version, or comp(*j, *i) == false for the second version) is met for any i in the range [first, nth) and for any j in the range [nth, last). The element placed in the nth position is exactly the element that would occur in this position if the range was fully sorted.

nth may be the end iterator, in this case the function has no effect.
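The quote mentions a second version that takes a comparator. Passing std::greater<int>() makes "sorted" mean descending order, so the same call with begin() + 2 places the 3rd largest element at index 2. A minimal sketch, where kth_largest is a made-up helper name and k is assumed to be less than v.size():

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Illustrative helper (not a standard function): returns the k-th largest
// element of `v`, where k = 0 is the largest. `v` is taken by value because
// nth_element reorders it. Assumes k < v.size().
int kth_largest(std::vector<int> v, std::size_t k) {
    auto nth = v.begin() + static_cast<std::ptrdiff_t>(k);
    // With std::greater, the element placed at nth is the one a
    // descending sort would put there.
    std::nth_element(v.begin(), nth, v.end(), std::greater<int>());
    return *nth;
}
```

Only the element at nth is pinned down; the elements on either side of it come back in unspecified order, just as in the ascending case.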


Solution

  • So where is the nth element here?

    The nth element is the 2 at index 2, because that is the position you asked for when you passed begin() + 2.

    The element pointed at by nth is changed to whatever element would occur in that position if [first, last) was sorted.

    This means that if the vector were sorted, its elements would be in this order:

    0 1 2 3 4 5 6 7 8 9 
        ^--- begin() + 2
    

    You asked for the 3rd smallest element to be at index 2 (the 3rd position), and that is what the algorithm put there.

    In addition, it arranges the remaining elements so that every element in front of it is no greater than it and every element behind it is no smaller:

    !(*j < *i) (for the first version, or comp(*j, *i) == false for the second version) is met for any i in the range [first, nth) and for any j in the range [nth, last).

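    In terms of indices rather than iterators: for any i < 2 and any j >= 2, v[i] <= v[j] (strictly v[i] < v[j] in this example, because all the values are distinct). In other words, 1 and 0 are both smaller than every element of 2 3 6 7 8 5 4 9. Note that neither part is sorted itself: 1 0 and 3 6 7 8 5 4 9 come back in arbitrary order.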
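As for how it can be implemented, one common approach is quickselect. You partition the range around a pivot, as in one step of quicksort, which leaves the pivot at its final sorted position. If that position is nth, you are done. Otherwise you repeat only on the side that contains nth and never touch the other side again, which is why the two groups end up unsorted. The sketch below is illustrative only: my_nth_element is a hypothetical name, it requires random-access iterators, picks the middle element as pivot, and uses a Lomuto partition. It is not the code any particular standard library ships.

```cpp
#include <algorithm>
#include <functional>

// Hypothetical quickselect sketch of what std::nth_element does.
// Average O(n) comparisons; worst case O(n^2) (e.g. many equal keys).
template <class RandomIt, class Compare = std::less<>>
void my_nth_element(RandomIt first, RandomIt nth, RandomIt last,
                    Compare comp = Compare{}) {
    // Invariant: nth lies in [first, last), everything before `first` is
    // no greater than, and everything from `last` on is no smaller than,
    // every element in [first, last).
    while (last - first > 1 && nth != last) {
        // Use the middle element as pivot; park it at the end.
        RandomIt mid = first + (last - first) / 2;
        std::iter_swap(mid, last - 1);
        RandomIt pivot = last - 1;

        // Lomuto partition: move every element that compares less than
        // the pivot to the front of the range.
        RandomIt store = first;
        for (RandomIt it = first; it != pivot; ++it) {
            if (comp(*it, *pivot)) {
                std::iter_swap(it, store);
                ++store;
            }
        }
        // Put the pivot between the two groups: [first, store) is less,
        // [store + 1, last) is not less, and *store is in sorted position.
        std::iter_swap(store, pivot);

        if (nth == store) return;       // pivot landed exactly at nth
        if (nth < store) last = store;  // continue in the left part only
        else first = store + 1;         // continue in the right part only
    }
}
```

Calling my_nth_element(v.begin(), v.begin() + 2, v.end()) on the vector from the question leaves 2 at index 2 with smaller elements in front and larger ones behind, though the exact order of the other elements may differ from what std::nth_element printed. Real library implementations refine this basic scheme (for example with better pivot selection), and the standard only requires linear complexity on average.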