c++ std median nth-element

Efficient median computation


I have an array, A, of length n. Let B be an array (that we never want to store separately - this is just to help explain) containing every k'th element of A. I want to find the median of B, and I want to move that element of A to the floor(n/2)'th position in A.

How can I do this efficiently? I'm thinking of trying to make a single call to std::nth_element, passing a pointer to A. However, I need this pointer to increment by k elements of A. How do I do this? Essentially:

A2 = (kFloat *)A;
std::nth_element(A2, A2 + (n/k)/2, A2 + (n/k));
swap(A[ ((n/k)/2)*k ], A[n/2]); // This might be redundant

where kFloat would be a structure that acts like a float, but when you increment the pointer it moves k*sizeof(float) in memory.

Note: I do not require the true median (average of middle two when n is even).

Edit: Another way of saying what I want (doesn't compile, because k is not a constant):

std::nth_element((float[k] * )A, ((float[k] * ) A)[(n / k) / 2], ((float[k] * ) A)[n / k]);

Edit 2: I am changing algorithm.cc, so I don't want to introduce dependencies on a library like Boost. I would like to use core C++11 functionality + std only.


Solution

  • For anyone else who has this problem in the future: I've modified some functions from algorithm.cc to take a stride parameter. Most of the helpers assume that the distance between _First and _Last is an exact multiple of the stride, so I don't recommend calling them directly. Instead, call the following entry point, which computes a suitable _Last from the array length for you:

    // Same as _Nth_element, but increments pointers by strides of k, so only
    // elements 0, k, 2k, ... of the array take part in the selection.
    // Takes n rather than a "last" iterator, to avoid confusion about what last
    // should be (the definition derives its own _Last from n and k).
    // _First = pointer to start of the array
    // _Nth = pointer to the position that we want to find the element for (if it were sorted).
    //          This position should be = _First + k*x, for some integer x. That is, it should be a multiple of k.
    // n = Length of array, _First, in primitive type (not length / k).
    // _Pred = comparison operator. Typically use less<>()
    // k = integer specifying the stride. If k = 10, we consider elements 0, 10, 20... only.
    // Note: n and k share the single template parameter intType, so they must
    // be passed with the same type for template deduction to succeed.
    template<class _RanIt, class intType, class _Pr> inline
    void _Nth_element_strided(_RanIt _First, _RanIt _Nth, intType n, _Pr _Pred, intType k);
    

    To call this function, you need to include this header:

    #ifndef _NTH_ELEMENT_STRIDED_H_
    #define _NTH_ELEMENT_STRIDED_H_
    
    // Picks a pivot candidate for the strided range and leaves it at *_Mid.
    // _First = first element considered
    // _Mid   = middle element considered (a multiple of k away from _First)
    // _Last  = LAST element considered (inclusive, not one-past-the-end)
    // _Pred  = comparison operator, forwarded to _Med3
    // k      = stride between considered elements
    // NOTE(review): _Med3 is the MSVC <algorithm> helper and is not defined in
    // this header; presumably it orders the three referenced elements so that
    // the median ends up in the middle one -- confirm against your STL version.
    template<class _RanIt, class intType, class _Pr> inline
    void _Median_strided(_RanIt _First, _RanIt _Mid, _RanIt _Last, _Pr _Pred, intType k) {
        if (!(40 < (_Last - _First)/k)) {
            // few enough strided elements: a single median of three suffices
            _Med3(_First, _Mid, _Last, _Pred);
            return;
        }

        // median of nine: sample three triples spaced one eighth of the
        // strided range apart (the gap is itself a multiple of k) ...
        size_t _Gap = k * ((_Last - _First + k) / (k*8));
        _Med3(_First, _First + _Gap, _First + 2 * _Gap, _Pred);
        _Med3(_Mid - _Gap, _Mid, _Mid + _Gap, _Pred);
        _Med3(_Last - 2 * _Gap, _Last - _Gap, _Last, _Pred);
        // ... then combine the three triple-middles into *_Mid
        _Med3(_First + _Gap, _Mid, _Last - _Gap, _Pred);
    }
    
    // Same as _Unguarded_partition, except it increments pointers by k.
    //
    // Partitions the strided elements _First, _First + k, ... (up to but not
    // including _Last) around a pivot chosen by _Median_strided, and returns the
    // pair (_Pfirst, _Plast) delimiting the run of elements that compare
    // equivalent to the pivot (the "fat pivot").
    //
    // Every comparison against _First / _Last below is made after stepping by
    // exactly k, so (_Last - _First) is expected to be a multiple of k.
    // NOTE(review): _DEBUG_LT_PRED and _STD are MSVC <algorithm> internals that
    // are not defined in this header; _DEBUG_LT_PRED(p, a, b) is presumably
    // p(a, b) plus debug-build checks -- confirm against your STL version.
    template<class _RanIt, class _Pr, class intType> inline
    pair<_RanIt, _RanIt> _Unguarded_partition_strided(_RanIt _First, _RanIt _Last, _Pr _Pred, intType k) {
        // partition [_First, _Last), using _Pred
        // middle strided element: half the element count, scaled back up by k
        _RanIt _Mid = _First + (((_Last - _First)/k) / 2)*k;
        // _Median_strided takes the last element itself, hence _Last - k
        _Median_strided(_First, _Mid, _Last - k, _Pred, k);
        _RanIt _Pfirst = _Mid;
        _RanIt _Plast = _Pfirst + k;
    
        // grow the pivot run downward while the neighbour below is neither
        // less than nor greater than the pivot
        while (_First < _Pfirst
            && !_DEBUG_LT_PRED(_Pred, *(_Pfirst - k), *_Pfirst)
            && !_Pred(*_Pfirst, *(_Pfirst - k)))
            _Pfirst -= k;
        // grow the pivot run upward by the same equivalence test
        while (_Plast < _Last
            && !_DEBUG_LT_PRED(_Pred, *_Plast, *_Pfirst)
            && !_Pred(*_Pfirst, *_Plast))
            _Plast += k;
    
        // _Gfirst scans upward from the top of the pivot run,
        // _Glast scans downward from the bottom of it
        _RanIt _Gfirst = _Plast;
        _RanIt _Glast = _Pfirst;
    
        for (;;) {
            // partition
            // upward scan: stop at the first element that is less than the pivot
            for (; _Gfirst < _Last; _Gfirst += k) {
                if (_DEBUG_LT_PRED(_Pred, *_Pfirst, *_Gfirst))
                    ;   // greater than pivot: already on the correct side
                else if (_Pred(*_Gfirst, *_Pfirst))
                    break;  // less than pivot: must move below the pivot run
                else if (_Plast != _Gfirst) {
                    // equivalent to pivot: swap it onto the top of the pivot run
                    _STD iter_swap(_Plast, _Gfirst);
                    _Plast += k;
                }
                else
                    _Plast += k;    // equivalent and already adjacent: just extend the run
            }
            // downward scan: stop at the first element that is greater than the pivot
            for (; _First < _Glast; _Glast -= k) {
                if (_DEBUG_LT_PRED(_Pred, *(_Glast - k), *_Pfirst))
                    ;   // less than pivot: already on the correct side
                else if (_Pred(*_Pfirst, *(_Glast - k)))
                    break;  // greater than pivot: must move above the pivot run
                else {
                    // equivalent to pivot: extend the pivot run downward
                    _Pfirst -= k;
                    if (_Pfirst != _Glast - k)
                        _STD iter_swap(_Pfirst, _Glast - k);
                }
            }
    
            // both scans reached their ends: partitioning is complete
            if (_Glast == _First && _Gfirst == _Last)
                return (pair<_RanIt, _RanIt>(_Pfirst, _Plast));
    
            if (_Glast == _First) {
                // no room at bottom, rotate pivot upward
                if (_Plast != _Gfirst)
                    _STD iter_swap(_Pfirst, _Plast);
                _Plast += k;
                _STD iter_swap(_Pfirst, _Gfirst);
                _Pfirst += k;
                _Gfirst += k;
            }
            else if (_Gfirst == _Last) {
                // no room at top, rotate pivot downward
                _Glast -= k;
                _Pfirst -= k;
                if (_Glast != _Pfirst)
                    _STD iter_swap(_Glast, _Pfirst);
                _Plast -= k;
                _STD iter_swap(_Pfirst, _Plast);
            }
            else {
                // one misplaced element on each side: exchange them and keep scanning
                _Glast -= k;
                _STD iter_swap(_Gfirst, _Glast);
                _Gfirst += k;
            }
        }
    }
    
    // Strided counterpart of move_backward: moves the elements at
    // _First, _First + k, ..., _Last - k so that they end at _Dest - k,
    // walking from the back, and returns the position of the first moved-to
    // element (i.e. _Dest minus one stride per element moved).
    // _First, _Last = source range; (_Last - _First) must be an exact multiple
    //                 of k, otherwise the loop condition is never satisfied
    // _Dest         = one stride past the last destination element
    // k             = stride between considered elements
    // Uses std::move directly instead of the MSVC-only _STD macro so that this
    // helper does not depend on a particular standard-library implementation.
    template<class _BidIt1, class _BidIt2, class intType> inline
    _BidIt2 _Move_backward_strided(_BidIt1 _First, _BidIt1 _Last, _BidIt2 _Dest, intType k) {
        // move [_First, _Last) backwards to [..., _Dest), one stride at a time
        while (_First != _Last) {
            _Dest -= k;
            _Last -= k;
            *_Dest = std::move(*_Last);
        }
        return (_Dest);
    }
    
    // Insertion sort over the strided elements _First, _First + k, ... up to
    // (but excluding) _Last, ordered by _Pred. Elements between the strides are
    // never read or written.
    // The unnamed _Ty* parameter only carries the element type for deduction.
    // (_Last - _First) must be a multiple of k, since the outer loop ends on
    // exact equality with _Last.
    // NOTE(review): _Move and _DEBUG_LT_PRED are MSVC <algorithm> internals not
    // defined in this header -- presumably a move cast and a checked call to
    // _Pred respectively; confirm against your STL version.
    template<class _BidIt, class _Pr, class intType, class _Ty> inline
    void _Insertion_sort1_strided(_BidIt _First, _BidIt _Last, _Pr _Pred, _Ty *, intType k) {
        if (_First == _Last)
            return;     // nothing to sort

        _BidIt _Cur = _First + k;
        while (_Cur != _Last) {
            // lift the current element out, leaving a hole at _Cur
            _BidIt _Hole = _Cur;
            _Ty _Val = _Move(*_Cur);

            if (_DEBUG_LT_PRED(_Pred, _Val, *_First)) {
                // new smallest element: shift the whole sorted prefix up by
                // one stride and drop the value at the front
                _Hole += k;
                _Move_backward_strided(_First, _Cur, _Hole, k);
                *_First = _Move(_Val);
            }
            else {
                // *_First is not greater than _Val, so this scan stops at
                // _First at the latest and needs no bounds check
                _BidIt _Below = _Hole - k;
                while (_DEBUG_LT_PRED(_Pred, _Val, *_Below)) {
                    *_Hole = _Move(*_Below);    // slide larger element up into the hole
                    _Hole = _Below;
                    _Below -= k;
                }
                *_Hole = _Move(_Val);   // fill the hole with the lifted value
            }

            _Cur += k;
        }
    }
    
    // Insertion-sorts the strided elements of [_First, _Last) using _Pred.
    // _Last must be exactly one stride (k) past the last element that takes
    // part in the sort, i.e. (last considered element) + k.
    // Forwards to _Insertion_sort1_strided, passing _Val_type(_First) as the
    // element-type tag argument.
    // NOTE(review): _Val_type is an MSVC <algorithm> internal that is not
    // defined in this header -- confirm it exists in your STL version.
    template<class _BidIt, class intType, class _Pr> inline
    void _Insertion_sort_strided(_BidIt _First, _BidIt _Last, _Pr _Pred, intType k) {
        _Insertion_sort1_strided(_First, _Last, _Pred, _Val_type(_First), k);
    }
    
    // Same as _Nth_element, but increments pointers by strides of k
    // Takes n, rather than last (needed to avoid confusion about what last should be [see first line below]
    // _First = pointer to start of the array
    // _Nth = pointer to the position that we want to find the element for (if it were sorted).
    //          This position should be = _First + k*x, for some integer x. That is, it should be a multiple of k.
    // n = Length of array, _First, in primitive type (not length / k).
    // _Pred = comparison operator. Typically use less<>()
    // k = integer specifying the stride. If k = 10, we consider elements 0, 10, 20... only.
    template<class _RanIt, class intType, class _Pr> inline
    void _Nth_element_strided(_RanIt _First, _RanIt _Nth, intType n, _Pr _Pred, intType k) {
    
        _RanIt _Last = (n % k == 0 ? _First + n : _First + (n / k + 1)*k);
        // order Nth element, using _Pred
        for (; _ISORT_MAX < (_Last - _First) / k;) {
            // divide and conquer, ordering partition containing Nth
            pair<_RanIt, _RanIt> _Mid = _Unguarded_partition_strided(_First, _Last, _Pred, k);
    
            if (_Mid.second <= _Nth)
                _First = _Mid.second;
            else if (_Mid.first <= _Nth)
                return; // Nth inside fat pivot, done
            else
                _Last = _Mid.first;
        }
    
        _Insertion_sort_strided(_First, _Last, _Pred, k);   // sort any remainder
    }
    
    #endif
    

    An example of using this function:

        // Randomized self-check: for random n, k and index, select the
        // index'th strided element of a with _Nth_element_strided and compare
        // it with _Nth_element run on b, a dense copy of the strided elements.
        // Runs forever; prints the iteration number on success and pauses on
        // a mismatch.
        for (int counter = 0; true; counter++) {
            // Test strided methods
            int n = (rand() % 10000) + 1;
            int k = (rand() % n) + 1;
            int * a = new int[n];
            // number of strided elements: ceil(n / k)
            int bLen = (n % k == 0 ? n / k : n / k + 1);
            int * b = new int[bLen];
            for (int i = 0; i < n; i++) // Initialize randomly
                a[i] = rand() % 100;
            for (int i = 0; i < bLen; i++)
                b[i] = a[i*k];
    
            int index = rand() % (bLen);    // Random index!
            _Nth_element(b, b + index, b + bLen, less<>());
            _Nth_element_strided(a, a + index*k, n, less<>(), k);
    
            if (b[index] != a[index*k]) {
                cout << "Not equal!" << endl;
                cout << b[index] << '\t' << a[index*k] << endl;
                getchar();
            }
            else
                cout << counter << endl;

            // release this iteration's buffers; the loop never terminates, so
            // without this every pass leaks both arrays
            delete[] a;
            delete[] b;
        }