c++ algorithm sorting cuda thrust

argsort in Thrust


Is it legal to use thrust::sort_by_key as in the following code?

#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/functional.h>  // thrust::identity
#include <thrust/tuple.h>       // thrust::tuple
#include <thrust/sort.h>
#include <thrust/advance.h>
#include <thrust/copy.h>
    
#include <cstdint>  // std::intptr_t
#include <iterator>
#include <iostream>
    
int main()
{
    int init[] = {2, 0, 1, 3, 4};
    const thrust::device_vector< int > v{std::cbegin(init), std::cend(init)};
        
    // parentheses select the size constructor; braces may pick an initializer_list overload
    thrust::device_vector< std::intptr_t > index(v.size());
    thrust::sequence(index.begin(), index.end());
        
    auto key = 
        thrust::make_permutation_iterator(
            thrust::make_transform_iterator(
                v.cbegin(), 
                thrust::identity< thrust::tuple< int > >{}),
            index.cbegin());
    thrust::sort_by_key(
        key,
        thrust::next(key, index.size()),
        index.begin());
        
    thrust::copy(
        index.cbegin(), index.cend(),
        std::ostream_iterator< std::intptr_t >(std::cout, ", "));
    std::cout << std::endl;
}

Here the index array points into the array of values v. I want to have a "sorted view" of v. After the sorting above, index is that view: [v[i] for i in index] (pythonic pseudocode) is sorted.
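
To make the goal concrete, here is a minimal host-side sketch of the property being asked for. It assumes std::vector in place of thrust::device_vector, and is_sorted_view is a hypothetical helper name, not part of Thrust:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true when reading `v` through `index` yields a non-decreasing
// sequence, i.e. when `index` is a valid argsort of `v`.
// Assumes every entry of `index` is a valid position in `v`.
bool is_sorted_view(const std::vector<int>& v,
                    const std::vector<std::intptr_t>& index)
{
    for (std::size_t i = 1; i < index.size(); ++i) {
        if (v[index[i]] < v[index[i - 1]]) {
            return false;  // two neighbouring view elements are out of order
        }
    }
    return true;
}
```

For v = {2, 0, 1, 3, 4}, the index {1, 2, 0, 3, 4} satisfies this check, while the initial sequence {0, 1, 2, 3, 4} does not.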

The trick with the identity transformation is crucial here: it transforms the values of v pointed to by index into one-element tuples. thrust::tuple is a class whose operator = is not ref-qualified for lvalues only, so it can be used on the rvalues returned by dereferencing the transform_iterator. thrust::tuple< int >(1) = 2; is a legal statement and is effectively a no-op, because the left-hand side is discarded right after the assignment. As a result, the key swaps in sort_by_key are all no-ops, and the real sorting happens in the "values" part of the key-value sort. Also note that v is immutable here (v.cbegin() returns a const iterator).
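
The rvalue-assignment behaviour can be sketched in plain C++. This assumes std::tuple as a stand-in for thrust::tuple, since its assignment operator is likewise usable on temporaries. The names wrap and assign_through_temporary are hypothetical and only mimic what dereferencing the transform_iterator does:

```cpp
#include <tuple>

// Hypothetical stand-in for dereferencing the transform_iterator:
// returns the element wrapped in a one-element tuple, by value.
std::tuple<int> wrap(const int& element)
{
    return std::tuple<int>{element};
}

// Assigns through the temporary and returns the source element afterwards.
int assign_through_temporary(int element, int new_value)
{
    // Compiles because tuple's operator= is not restricted to lvalues.
    // The write lands in the temporary, which is destroyed at the end of
    // the statement, so `element` is never modified.
    wrap(element) = std::tuple<int>{new_value};
    return element;
}
```

Calling assign_through_temporary(1, 2) returns 1: the assignment is accepted by the compiler but has no observable effect on the source. Had wrap returned a plain int, the same assignment would be rejected, which is why the one-element tuple is needed.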

As far as I know, the developers of Thrust generally assume that all callables are idempotent. I believe this assumption is not violated here, because only the argument of the callable (thrust::identity) changes, not the state of the callable. On the other hand, any composition of Thrust fancy iterators can be regarded as a composition of functions (a permutation_iterator, for example, is a simple mapping).

sort_by_key both reads and writes index: it is read through the key iterator and written as the values range. This may be prohibited by implicit rules of the implementation. Is this code correct?


Solution

  • While I cannot say whether your implementation relies on undefined behavior, or on implementation details of the algorithm that happen not to cause race conditions on index, there is a way to implement argsort in Thrust using thrust::sort with a custom comparator instead of thrust::sort_by_key. It is much more readable and therefore easier to reason about:

    #include <thrust/device_vector.h>
    #include <thrust/sort.h>
    #include <thrust/advance.h>
    #include <thrust/copy.h>
    #include <thrust/iterator/counting_iterator.h>
    
    #include <cstdint>  // std::intptr_t
    #include <iterator>
    #include <iostream>
    
    int main()
    {
        int init[] = {2, 0, 1, 3, 4};
        const thrust::device_vector< int > v{std::cbegin(init), std::cend(init)};
    
        // optimization to avoid unnecessary initialization of index to zero
        auto const seq_iter =
            thrust::make_counting_iterator(
                static_cast< std::intptr_t >(0));
    
        thrust::device_vector< std::intptr_t > index{seq_iter,
                                                     thrust::next(seq_iter, v.size())};
        
        auto const v_ptr = v.data();
    
        thrust::sort(
            index.begin(), index.end(),
            [v_ptr] __host__ __device__ (std::intptr_t left_idx, std::intptr_t right_idx)
            {
                return v_ptr[left_idx] < v_ptr[right_idx];
            });
    
        thrust::copy(
            index.cbegin(), index.cend(),
            std::ostream_iterator< std::intptr_t >(std::cout, ", "));
        std::cout << std::endl;
    }
    

    Due to the device lambda, nvcc needs the -extended-lambda flag to compile this snippet. One can implement the comparator for sort with a named functor instead of a lambda to avoid the need for this flag. I used a __host__ __device__ lambda as pure __device__ lambdas are generally problematic.
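
A minimal sketch of the named-functor variant mentioned above. IndexLess, ARGSORT_HD and argsort_host are hypothetical names, and the snippet is exercised on the host with std::sort so that it compiles without nvcc. The macro only adds the execution-space qualifiers when the CUDA compiler is used:

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Expands to the CUDA execution-space qualifiers only under nvcc.
#ifdef __CUDACC__
#define ARGSORT_HD __host__ __device__
#else
#define ARGSORT_HD
#endif

// Hypothetical named comparator: orders indices by the values they refer to.
struct IndexLess
{
    const int* values;  // raw pointer to the value array

    ARGSORT_HD bool operator()(std::intptr_t left_idx,
                               std::intptr_t right_idx) const
    {
        return values[left_idx] < values[right_idx];
    }
};

// Host-only illustration of the same comparator with std::sort.
std::vector<std::intptr_t> argsort_host(const std::vector<int>& v)
{
    std::vector<std::intptr_t> index(v.size());
    std::iota(index.begin(), index.end(), std::intptr_t{0});
    std::sort(index.begin(), index.end(), IndexLess{v.data()});
    return index;
}
```

With Thrust, the same functor could presumably be passed as thrust::sort(index.begin(), index.end(), IndexLess{thrust::raw_pointer_cast(v.data())}), in which case no lambda flag should be needed. Note that neither std::sort nor thrust::sort is stable, so the relative order of indices that refer to equal values is unspecified.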