Tags: cuda, thrust

Further optimization opportunities for a Thrust-based replacement of a CUDA kernel


I have a CUDA kernel which essentially looks like the following.

// Zeroes vals[n*K + i] wherever data[n*K + i] >= crit[n], for the first
// nums[n] elements of every width-K segment n; all other elements of vals are
// left untouched.
//
// Launch: 1-D grid with at least N threads in total (one thread per segment);
// threads with index >= N return immediately.
// Preconditions: vals and data hold N*K elements, crit and nums hold N.
// nums[n] is clamped to K so an oversized count cannot run into segment n+1.
__global__ void myOpKernel(double *vals, double *data, int *nums, double *crit, int N, int K) {
  int index = blockIdx.x*blockDim.x + threadIdx.x;

  if (index >= N) return;

  // 64-bit segment offset: index*K computed in int overflows once N*K > INT_MAX.
  const size_t base  = static_cast<size_t>(index) * static_cast<size_t>(K);
  const double _crit = crit[index];
  // Load the count once instead of re-reading nums[index] on every iteration.
  const int count = min(nums[index], K);

  for (int i = 0; i < count; i++) {
    // Only store when the value actually changes; untouched elements keep
    // their value exactly as before.
    if (data[base + i] >= _crit) { vals[base + i] = 0.0; }
  }
}

This kernel updates vals (N*K elements) segment by segment. For segment n (width K), it examines the first nums[n] elements and sets vals[n*K + i] to 0.0 wherever data[n*K + i] >= crit[n]. Where data is smaller than crit[n], and for all elements beyond the first nums[n], vals is left unchanged.

For example, data under consideration will look like the following

  int N = 3;
  int K = 5;

  vals[ 0] = 1.0; data[ 0] = 5.1; crit[0] = 5.0; nums[0] = 3;
  vals[ 1] = 1.0; data[ 1] = 4.9;
  vals[ 2] = 1.0; data[ 2] = 3.0;
  vals[ 3] = 0.0; data[ 3] = 0.0;
  vals[ 4] = 0.0; data[ 4] = 0.0;
  //-----------------------
  vals[ 5] = 1.0; data[ 5] = 2.9; crit[1] = 3.0; nums[1] = 2;
  vals[ 6] = 1.0; data[ 6] = 3.1;
  vals[ 7] = 0.0; data[ 7] = 0.0;
  vals[ 8] = 0.0; data[ 8] = 0.0;
  vals[ 9] = 0.0; data[ 9] = 0.0;
  //-----------------------
  vals[10] = 1.0; data[10] = 8.1; crit[2] = 9.0; nums[2] = 5;
  vals[11] = 1.0; data[11] = 7.8;
  vals[12] = 1.0; data[12] = 9.1;
  vals[13] = 1.0; data[13] = 200.;
  vals[14] = 1.0; data[14] = -1.0;

I noticed that this operation is one of the three most time-consuming kernels, and I am considering accelerating it with Thrust.

What I have come up with so far looks like the following. It uses the expand function from the Thrust examples (https://github.com/NVIDIA/thrust/blob/master/examples/expand.cu).

// Element-wise functor for thrust::transform over the zipped tuple
//   (vals, data, index / K, index % K, nums_expand, crit_expand),
// where index is the flat position in [0, N*K) and nums_expand / crit_expand
// hold nums[n] / crit[n] repeated for every element of segment n.
// Returns 0.0 when the element lies within the first nums[n] positions of its
// segment (index % K < nums) and data >= crit; otherwise returns vals
// unchanged, matching myOpKernel.
//
// Bug fix: the segment number (tuple element 2, index / K) used to be compared
// against nums as well, which left every segment n with n >= nums[n]
// untouched and made the result differ from myOpKernel. Element 2 is still part
// of the tuple type for interface compatibility, but it is ignored.
//
// NOTE(review): thrust::unary_function is deprecated in recent Thrust/CCCL
// releases; the base class is kept only so its nested typedefs remain available.
struct myOp : public thrust::unary_function<thrust::tuple<double,double,int,int,int,double>, double> {
                                        // vals   data   index/K index%K nums crit
  __host__ __device__                   // 0      1      2       3       4    5
    double operator() (const thrust::tuple<double,double,int,int,int, double> &t) const {
      const double val    = thrust::get<0>(t);
      const double datum  = thrust::get<1>(t);
      const int    offset = thrust::get<3>(t);  // position within the segment
      const int    num    = thrust::get<4>(t);  // nums[n] of this segment
      const double limit  = thrust::get<5>(t);  // crit[n] of this segment

      // Only the first `num` elements of a segment are candidates for zeroing.
      if (offset >= num) { return val; }
      return (datum >= limit) ? 0.0 : val;
    }
};

// Sketch of the Thrust-based replacement for myOpKernel. Not compilable as
// shown: N and K are defined elsewhere, and the filling code is elided.
// NOTE(review): the result is written to `res`; `vals` itself is not updated,
// unlike myOpKernel, which modifies vals in place.
int main() {

  using namespace thrust::placeholders;

  thrust::device_vector<double> vals(N*K);
  thrust::device_vector<double> data(N*K);
  thrust::device_vector<double> crit(N);
  thrust::device_vector<int>    nums(N);

  thrust::device_vector<double> res(N*K);

  // ... fill values ...

  // Per-element copies of the per-segment nums/crit values (N*K entries each).
  thrust::device_vector<int>    nums_expand(N*K);
  thrust::device_vector<double> crit_expand(N*K);

  // 'expand()' does something like [1,2,3] -> [1,1,1,2,2,2,3,3,3]
  // (here each of the N per-segment values is repeated K times).
  expand(thrust::constant_iterator<int>(K),
         thrust::constant_iterator<int>(K)+N,
         nums.begin(),
         nums_expand.begin());

  expand(thrust::constant_iterator<int>(K),
         thrust::constant_iterator<int>(K)+N,
         crit.begin(),
         crit_expand.begin());

  // Zip every element with its segment number (index / K), its position in the
  // segment (index % K) and the expanded nums/crit, then apply myOp to each.
  thrust::transform(thrust::make_zip_iterator(vals.begin(),
                                              data.begin(),
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1/K), // index related to N
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1%K), // index related to K
                                              nums_expand.begin(),
                                              crit_expand.begin()),
                    thrust::make_zip_iterator(vals.end(),
                                              data.end(),
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1/K) + N*K,
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1%K) + N*K,
                                              nums_expand.end(),
                                              crit_expand.end()),
                    res.begin(),
                    myOp());

  ...

}

When I tried this with arbitrary array values for [N,K] = [1000,256], [10000,256], [50000,256] and [100000,256], the performance was already satisfactory.

[Benchmark plot: the image is not included in this text.]

But I wonder whether my Thrust operations can be sped up further. I expand the per-segment values (nums and crit) so that they can be used in the functor's if statements, but maybe this can be avoided with a permutation_iterator or similar; I cannot figure out how. I also compute _1/K and _1%K to obtain the segment index and the position within the segment, which could perhaps be avoided with a cleverer approach.

At least from a cosmetic point of view, I would like to fold expand(...) directly into thrust::transform(...) without having to define additional vectors such as nums_expand.

Any suggestions for any chance of improvements are welcome.

Full code used for the comparison

// https://stackoverflow.com/questions/31955505/can-thrust-transform-reduce-work-with-2-arrays

#include <thrust/device_vector.h>

#include <thrust/reduce.h>
#include <thrust/gather.h>
#include <thrust/copy.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
//#include <thrust/execution_policy.h>
#include <iostream>
#include <iomanip>

#include <thrust/transform.h>
#include <thrust/functional.h>

#include <helper_timer.h>

/////////  https://github.com/NVIDIA/thrust/blob/master/examples/expand.cu //////////
// Expands each value of first2 into a run of copies whose length is the
// matching count in [first1, last1); e.g. counts [2,0,1] with values [A,B,C]
// produce [A,A,C]. Adapted from the Thrust `expand.cu` example.
//
// first1, last1 : range of non-negative repetition counts.
// first2        : values to replicate, one per count.
// output        : destination with room for sum(counts) elements.
// Returns output advanced by the number of elements written.
//
// Allocates two temporary device_vectors (input_size and output_size
// elements) on every call.
template <typename InputIterator1,
          typename InputIterator2,
          typename OutputIterator>
OutputIterator expand(InputIterator1 first1,
                      InputIterator1 last1,
                      InputIterator2 first2,
                      OutputIterator output)
{
  typedef typename thrust::iterator_difference<InputIterator1>::type difference_type;

  difference_type input_size  = thrust::distance(first1, last1);
  // Accumulate in difference_type: reducing in the counts' own value type
  // (e.g. int for constant_iterator<int>) overflows once the total exceeds INT_MAX.
  difference_type output_size = thrust::reduce(first1, last1, difference_type(0));

  // scan the counts to obtain output offsets for each input element
  // (explicit difference_type init so the scan accumulates in the same type)
  thrust::device_vector<difference_type> output_offsets(input_size, 0);
  thrust::exclusive_scan(first1, last1, output_offsets.begin(), difference_type(0));

  // scatter the nonzero counts into their corresponding output positions
  thrust::device_vector<difference_type> output_indices(output_size, 0);
  thrust::scatter_if
    (thrust::counting_iterator<difference_type>(0),
     thrust::counting_iterator<difference_type>(input_size),
     output_offsets.begin(),
     first1,
     output_indices.begin());

  // compute max-scan over the output indices, filling in the holes
  thrust::inclusive_scan
    (output_indices.begin(),
     output_indices.end(),
     output_indices.begin(),
     thrust::maximum<difference_type>());

  // gather input values according to index array (output = first2[output_indices])
  thrust::gather(output_indices.begin(),
                 output_indices.end(),
                 first2,
                 output);

  // return output + output_size
  thrust::advance(output, output_size);
  return output;
}

/////////////////////////////////////////////////////////////////////////////////////

// Writes all elements of `vec` on a single line, each right-aligned in a
// 5-character field, and terminates the line with std::endl (which flushes).
template<typename T>
void print_vector(T& vec) {
  for (auto it = vec.begin(); it != vec.end(); ++it)
    std::cout << std::setw(5) << *it;
  std::cout << std::endl;
}

// Prints the timer's accumulated value divided by `average` runs. The value
// from sdkGetTimerValue is scaled by 1e-3 and printed as seconds, i.e. it is
// treated as milliseconds.
void printSdkTimer(StopWatchInterface **timer, int average) {
  const float totalValue = (float)sdkGetTimerValue(timer);
  const float secondsPerRun = (float)1.0e-3 * totalValue / (float)average;
  printf(" - Elapsed time: %.5f sec \n", secondsPerRun);
}

// Element-wise functor for thrust::transform over the zipped tuple
//   (vals, data, index / K, index % K, nums_expand, crit_expand),
// where index is the flat position in [0, N*K) and nums_expand / crit_expand
// hold nums[n] / crit[n] repeated for every element of segment n.
// Returns 0.0 when the element lies within the first nums[n] positions of its
// segment (index % K < nums) and data >= crit; otherwise returns vals
// unchanged, matching myOpKernel.
//
// Bug fix: the segment number (tuple element 2, index / K) used to be compared
// against nums as well, which left every segment n with n >= nums[n]
// untouched and made the result differ from myOpKernel. Element 2 is still part
// of the tuple type for interface compatibility, but it is ignored.
//
// NOTE(review): thrust::unary_function is deprecated in recent Thrust/CCCL
// releases; the base class is kept only so its nested typedefs remain available.
struct myOp : public thrust::unary_function<thrust::tuple<double,double,int,int,int,double>, double> {
                                  // vals   data   index/K index%K nums crit
  __host__ __device__             // 0      1      2       3       4    5
    double operator() (const thrust::tuple<double,double,int,int,int, double> &t) const {
      const double val    = thrust::get<0>(t);
      const double datum  = thrust::get<1>(t);
      const int    offset = thrust::get<3>(t);  // position within the segment
      const int    num    = thrust::get<4>(t);  // nums[n] of this segment
      const double limit  = thrust::get<5>(t);  // crit[n] of this segment

      // Only the first `num` elements of a segment are candidates for zeroing.
      if (offset >= num) { return val; }
      return (datum >= limit) ? 0.0 : val;
    }
};

// Zeroes vals[n*K + i] wherever data[n*K + i] >= crit[n], for the first
// nums[n] elements of every width-K segment n; all other elements of vals are
// left untouched.
//
// Launch: 1-D grid with at least N threads in total (one thread per segment);
// threads with index >= N return immediately.
// Preconditions: vals and data hold N*K elements, crit and nums hold N.
// nums[n] is clamped to K so an oversized count cannot run into segment n+1.
__global__ void myOpKernel(double *vals, double *data, int *nums, double *crit, int N, int K) {
  int index = blockIdx.x*blockDim.x + threadIdx.x;

  if (index >= N) return;

  // 64-bit segment offset: index*K computed in int overflows once N*K > INT_MAX.
  const size_t base  = static_cast<size_t>(index) * static_cast<size_t>(K);
  const double _crit = crit[index];
  // Load the count once instead of re-reading nums[index] on every iteration.
  const int count = min(nums[index], K);

  for (int i = 0; i < count; i++) {
    // Only store when the value actually changes; untouched elements keep
    // their value exactly as before.
    if (data[base + i] >= _crit) { vals[base + i] = 0.0; }
  }
}

// Benchmark driver: times the Thrust transform (myOp) against myOpKernel on
// N segments of width K and reports whether both produce the same values.
// Usage: <prog> N K
int main(int argc, char **argv) {

  using namespace thrust::placeholders;

  if (argc < 3) {
    std::cerr << "usage: " << argv[0] << " N K" << std::endl;
    return 1;
  }
  int N = atoi(argv[1]);
  int K = atoi(argv[2]);
  if (N <= 0 || K <= 0) {
    std::cerr << "N and K must be positive" << std::endl;
    return 1;
  }

  std::cout << "N " << N << " K " << K << std::endl;

  // Initialise on the device in bulk. Assigning device_vector elements one at
  // a time from the host issues a separate host->device copy per element.
  thrust::device_vector<double> vals(N*K);
  thrust::device_vector<double> data(N*K);
  thrust::device_vector<double> crit(N, 101.0); // arbitrary
  thrust::device_vector<int>    nums(N, 200);   // arbitrary number less than 256

  // vals[i] = data[i] = (double)i -- the same arbitrary values as before.
  thrust::copy(thrust::counting_iterator<int>(0),
               thrust::counting_iterator<int>(N*K),
               vals.begin());
  thrust::copy(thrust::counting_iterator<int>(0),
               thrust::counting_iterator<int>(N*K),
               data.begin());

  thrust::device_vector<double> res(N*K);

  // to be used for kernel
  thrust::device_vector<double> vals2 = vals;
  thrust::device_vector<double> data2 = data;
  thrust::device_vector<double> crit2 = crit;
  thrust::device_vector<int>    nums2 = nums;

  StopWatchInterface *timer=NULL;

//--- 1) thrust
  thrust::device_vector<int>    nums_expand(N*K);
  thrust::device_vector<double> crit_expand(N*K);

  expand(thrust::constant_iterator<int>(K),
         thrust::constant_iterator<int>(K)+N,
         nums.begin(),
         nums_expand.begin());

  expand(thrust::constant_iterator<int>(K),
         thrust::constant_iterator<int>(K)+N,
         crit.begin(),
         crit_expand.begin());

  sdkCreateTimer(&timer);
  sdkStartTimer(&timer);

  thrust::transform(thrust::make_zip_iterator(vals.begin(),
                                              data.begin(),
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1/K), // for N
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1%K), // for K
                                              nums_expand.begin(),
                                              crit_expand.begin()),
                    thrust::make_zip_iterator(vals.end(),
                                              data.end(),
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1/K) + N*K,
                                              thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1%K) + N*K,
                                              nums_expand.end(),
                                              crit_expand.end()),
                    res.begin(),
                    myOp());

  // Wait for the GPU before stopping the host timer so the measurement covers
  // execution rather than only the launch.
  cudaDeviceSynchronize();
  sdkStopTimer(&timer);
  printSdkTimer(&timer,1);

  sdkResetTimer(&timer);
  sdkStartTimer(&timer);

//--- 2) kernel
  double *raw_vals2 = thrust::raw_pointer_cast(vals2.data());
  double *raw_data2 = thrust::raw_pointer_cast(data2.data());
  double *raw_crit2 = thrust::raw_pointer_cast(crit2.data());
  int    *raw_nums2 = thrust::raw_pointer_cast(nums2.data());

  // myOpKernel uses one thread per segment (threads with index >= N return
  // immediately), so N threads are enough -- not N*K.
  int Nthreads = 256;
  int Nblocks = (N + Nthreads - 1) / Nthreads;
  myOpKernel<<<Nblocks,Nthreads>>>(raw_vals2, raw_data2, raw_nums2, raw_crit2, N, K);
  const cudaError_t launchErr = cudaGetLastError();

  cudaDeviceSynchronize();

  sdkStopTimer(&timer);
  if (launchErr != cudaSuccess) {
    std::cerr << "myOpKernel launch failed: " << cudaGetErrorString(launchErr) << std::endl;
    sdkDeleteTimer(&timer);
    return 1;
  }
  printSdkTimer(&timer,1);

  sdkDeleteTimer(&timer);

  // The Thrust path writes its output to res; the kernel updates vals2 in place.
  std::cout << (res == vals2 ? "results match" : "results do not match") << std::endl;

  return 0;
}

Solution

  • Below is some modified benchmark code with improved kernels. Compiled with nvcc --extended-lambda -arch=sm_89 -O3 main.cu -o main

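    Since the timer helper (helper_timer.h) is not part of your posted code, I use CUDA events instead. Data buffers are initialized in host vectors to avoid millions of memcopies. I also noticed that the Thrust approach does not produce results identical to your kernel for large N. The cause is the first test in myOp as originally posted: it compares the segment index (index / K, tuple element 2) against nums. As a result, every segment whose index is at least nums[n] is left untouched. Only the position within the segment (index % K) should be compared against nums.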

    I added two kernels. myOpKernel3 simply uses one thread block per segment index; the block's threads cooperatively process the first nums[index] elements of that segment.

    myOpKernel4 uses one thread per processed element. This requires a prefix sum of nums and, for each thread, the index of the segment it belongs to. I chose to precompute these indices. An alternative would be to perform a binary search on the prefix sum within the kernel.

    For full segments, i.e. nums[i] = 256, the output is

    N 100000 K 256
    expandtime 6.28806 ms
    thrusttransformtime 1.05165 ms
    myOpKerneltime 2.04301 ms
    results from thrust and myOpKernel do not match
    myOpKerneltime3 0.662016 ms
    myOpKernel4_setuptime 0.723872 ms
    myOpKerneltime4 0.785408 ms
    

    For nums[i] = 128

    N 100000 K 256
    expandtime 6.2976 ms
    thrusttransformtime 1.05472 ms
    myOpKerneltime 1.04054 ms
    results from thrust and myOpKernel do not match
    myOpKerneltime3 0.337952 ms
    myOpKernel4_setuptime 0.369152 ms
    myOpKerneltime4 0.386048 ms
    

    nums[i] = 4

    N 100000 K 256
    expandtime 6.2936 ms
    thrusttransformtime 1.05165 ms
    myOpKerneltime 0.273536 ms
    results from thrust and myOpKernel do not match
    myOpKerneltime3 0.119104 ms
    myOpKernel4_setuptime 0.293824 ms
    myOpKerneltime4 0.066848 ms
    

    I did not test non-uniform segment sizes. Note that performance costs of temporary memory allocations and thrust calls can be reduced by using custom allocators in conjunction with thrust's thrust::cuda::par_nosync execution policy.

    // https://stackoverflow.com/questions/31955505/can-thrust-transform-reduce-work-with-2-arrays
    
    #include <thrust/device_vector.h>
    #include <thrust/host_vector.h>
    
    #include <thrust/reduce.h>
    #include <thrust/scan.h>
    #include <thrust/scatter.h>
    #include <thrust/gather.h>
    #include <thrust/copy.h>
    #include <thrust/transform.h>
    #include <thrust/functional.h>
    #include <thrust/execution_policy.h>
    #include <thrust/iterator/transform_iterator.h>
    #include <thrust/iterator/counting_iterator.h>
    #include <thrust/iterator/constant_iterator.h>
    #include <thrust/iterator/zip_iterator.h>
    #include <thrust/iterator/discard_iterator.h>
    
    #include <cassert>
    #include <cstdlib>
    #include <iostream>
    #include <iomanip>
    
    
    /////////  https://github.com/NVIDIA/thrust/blob/master/examples/expand.cu //////////
    // Expands each value of first2 into a run of copies whose length is the
    // matching count in [first1, last1); e.g. counts [2,0,1] with values [A,B,C]
    // produce [A,A,C]. Adapted from the Thrust `expand.cu` example.
    //
    // first1, last1 : range of non-negative repetition counts.
    // first2        : values to replicate, one per count.
    // output        : destination with room for sum(counts) elements.
    // Returns output advanced by the number of elements written.
    //
    // Allocates two temporary device_vectors (input_size and output_size
    // elements) on every call.
    template <typename InputIterator1,
              typename InputIterator2,
              typename OutputIterator>
    OutputIterator expand(InputIterator1 first1,
                          InputIterator1 last1,
                          InputIterator2 first2,
                          OutputIterator output)
    {
      typedef typename thrust::iterator_difference<InputIterator1>::type difference_type;
      
      difference_type input_size  = thrust::distance(first1, last1);
      // Accumulate in difference_type: reducing in the counts' own value type
      // (e.g. int for constant_iterator<int>) overflows once the total exceeds INT_MAX.
      difference_type output_size = thrust::reduce(first1, last1, difference_type(0));
    
      // scan the counts to obtain output offsets for each input element
      // (explicit difference_type init so the scan accumulates in the same type)
      thrust::device_vector<difference_type> output_offsets(input_size, 0);
      thrust::exclusive_scan(first1, last1, output_offsets.begin(), difference_type(0));
    
      // scatter the nonzero counts into their corresponding output positions
      thrust::device_vector<difference_type> output_indices(output_size, 0);
      thrust::scatter_if
        (thrust::counting_iterator<difference_type>(0),
         thrust::counting_iterator<difference_type>(input_size),
         output_offsets.begin(),
         first1,
         output_indices.begin());
    
      // compute max-scan over the output indices, filling in the holes
      thrust::inclusive_scan
        (output_indices.begin(),
         output_indices.end(),
         output_indices.begin(),
         thrust::maximum<difference_type>());
    
      // gather input values according to index array (output = first2[output_indices])
      thrust::gather(output_indices.begin(),
                     output_indices.end(),
                     first2,
                     output);
    
      // return output + output_size
      thrust::advance(output, output_size);
      return output;
    }
    
    /////////////////////////////////////////////////////////////////////////////////////
    
    // Writes all elements of `vec` on a single line, each right-aligned in a
    // 5-character field, and terminates the line with std::endl (which flushes).
    template<typename T>
    void print_vector(T& vec) {
      for (auto it = vec.begin(); it != vec.end(); ++it)
        std::cout << std::setw(5) << *it;
      std::cout << std::endl;
    }
    
    // Element-wise functor for thrust::transform over the zipped tuple
    //   (vals, data, index / K, index % K, nums_expand, crit_expand),
    // where index is the flat position in [0, N*K) and nums_expand / crit_expand
    // hold nums[n] / crit[n] repeated for every element of segment n.
    // Returns 0.0 when the element lies within the first nums[n] positions of its
    // segment (index % K < nums) and data >= crit; otherwise returns vals
    // unchanged, matching myOpKernel.
    //
    // Bug fix: the segment number (tuple element 2, index / K) used to be compared
    // against nums as well, which left every segment n with n >= nums[n]
    // untouched and made the result differ from myOpKernel. Element 2 is still part
    // of the tuple type for interface compatibility, but it is ignored.
    //
    // NOTE(review): thrust::unary_function is deprecated in recent Thrust/CCCL
    // releases; the base class is kept only so its nested typedefs remain available.
    struct myOp : public thrust::unary_function<thrust::tuple<double,double,int,int,int,double>, double> {
                                      // vals   data   index/K index%K nums crit
      __host__ __device__             // 0      1      2       3       4    5
        double operator() (const thrust::tuple<double,double,int,int,int, double> &t) const {
          const double val    = thrust::get<0>(t);
          const double datum  = thrust::get<1>(t);
          const int    offset = thrust::get<3>(t);  // position within the segment
          const int    num    = thrust::get<4>(t);  // nums[n] of this segment
          const double limit  = thrust::get<5>(t);  // crit[n] of this segment
    
          // Only the first `num` elements of a segment are candidates for zeroing.
          if (offset >= num) { return val; }
          return (datum >= limit) ? 0.0 : val;
        }
    };
    
    // Zeroes vals[n*K + i] wherever data[n*K + i] >= crit[n], for the first
    // nums[n] elements of every width-K segment n; all other elements of vals are
    // left untouched.
    //
    // Launch: 1-D grid with at least N threads in total (one thread per segment);
    // threads with index >= N return immediately.
    // Preconditions: vals and data hold N*K elements, crit and nums hold N.
    // nums[n] is clamped to K so an oversized count cannot run into segment n+1.
    __global__ void myOpKernel(double *vals, double *data, int *nums, double *crit, int N, int K) {
      int index = blockIdx.x*blockDim.x + threadIdx.x;
    
      if (index >= N) return;
    
      // 64-bit segment offset: index*K computed in int overflows once N*K > INT_MAX.
      const size_t base  = static_cast<size_t>(index) * static_cast<size_t>(K);
      const double _crit = crit[index];
      // Load the count once instead of re-reading nums[index] on every iteration.
      const int count = min(nums[index], K);
    
      for (int i = 0; i < count; i++) {
        // Only store when the value actually changes; untouched elements keep
        // their value exactly as before.
        if (data[base + i] >= _crit) { vals[base + i] = 0.0; }
      }
    }
    
    
      // Same operation as myOpKernel, but with one thread block per segment:
      // block b handles segments b, b + gridDim.x, ..., and its threads stride by
      // blockDim.x over the first nums[index] elements, so adjacent threads touch
      // adjacent elements. Works with any 1-D grid/block size (main launches N
      // blocks of 256 threads).
      // Preconditions: vals and data hold N*K elements, crit and nums hold N.
      // nums[index] is clamped to K so an oversized count cannot spill into the
      // next segment.
      __global__ void myOpKernel3(
          double * __restrict__ vals, 
          const double * __restrict__ data, 
          const int * __restrict__ nums, 
          const double * __restrict__ crit, 
          int N, 
          int K
        ){
            for(int index = blockIdx.x; index < N; index += gridDim.x){
                // 64-bit segment offset: index*K in int overflows once N*K > INT_MAX.
                const size_t base = static_cast<size_t>(index) * static_cast<size_t>(K);
                const double _crit = crit[index];
                const int num = min(nums[index], K);
                for(int i = threadIdx.x; i < num; i += blockDim.x){
                    // Store only when the value changes; others keep their value.
                    if (data[base + i] >= _crit) { vals[base + i] = 0.0; }
                }
            }
      }
    
      // Same operation as myOpKernel, with one thread per processed element
      // (1-D grid of at least totalnums threads; extra threads return at once).
      // Thread tid handles element i = tid - numsPrefixSum[index] of segment
      // index = indexForThread[tid].
      // numsPrefixSum : exclusive prefix sum of nums (as built in main: N+1
      //                 entries, numsPrefixSum[0] == 0).
      // indexForThread: segment index of each of the totalnums processed elements.
      // nums and N are not read here; they are kept for a signature parallel to
      // the other kernels.
      __global__ 
      void myOpKernel4(
        double * __restrict__ vals, 
        const double * __restrict__ data, 
        const int * __restrict__ nums, 
        const double * __restrict__ crit, 
        const int* __restrict__ numsPrefixSum,
        const int* __restrict__ indexForThread,
        int totalnums,
        int N, 
        int K
      ){
          const int tid = threadIdx.x + blockIdx.x * blockDim.x;
          if (tid >= totalnums) return;

          const int index = indexForThread[tid];
          const int i = tid - numsPrefixSum[index];
          // 64-bit element offset: index*K in int overflows once N*K > INT_MAX.
          const size_t pos = static_cast<size_t>(index) * static_cast<size_t>(K) + i;
          // Store only when the value changes; others keep their value.
          if (data[pos] >= crit[index]) { vals[pos] = 0.0; }
    }
    
    // Benchmark driver: times the expand step, the Thrust transform (myOp) and
    // the kernels myOpKernel, myOpKernel3 and myOpKernel4 on N segments of
    // width K, and checks that they all produce the same values.
    // Usage: ./main N K
    int main(int argc, char **argv) {
    
      using namespace thrust::placeholders;
    
      if (argc < 3) {
        std::cerr << "usage: " << argv[0] << " N K\n";
        return 1;
      }
      int N = atoi(argv[1]); 
      int K = atoi(argv[2]); 
      if (N <= 0 || K <= 0) {
        std::cerr << "N and K must be positive\n";
        return 1;
      }
    
      std::cout << "N " << N << " K " << K << std::endl;
    
      // Reports (and clears) a launch/configuration error of the most recent
      // kernel launch; returns false if there was one.
      auto launchOk = [](const char *name) {
        const cudaError_t err = cudaGetLastError();
        if (err != cudaSuccess) {
          std::cerr << name << " launch failed: " << cudaGetErrorString(err) << "\n";
          return false;
        }
        return true;
      };
    
      // Initialise on the host and copy each buffer to the device once.
      thrust::host_vector<double> h_vals(N*K);
      thrust::host_vector<double> h_data(N*K);
      thrust::host_vector<double> h_crit(N);
      thrust::host_vector<int>    h_nums(N);
    
      for (int i=0; i<N; i++) {
        h_crit[i] = 101.0; // arbitrary
        h_nums[i] = 4;   // arbitrary number less than 256
        for (int j=0; j<K; j++) {
          h_vals[i*K + j] = (double)(i*K + j); // arbitrary
          h_data[i*K + j] = (double)(i*K + j); // arbitrary
        }
      }
    
      thrust::device_vector<double> vals = h_vals;
      thrust::device_vector<double> data = h_data;
      thrust::device_vector<double> crit = h_crit;
      thrust::device_vector<int>    nums = h_nums;
    
      thrust::device_vector<double> res(vals.size());
    
      cudaEvent_t eventA; cudaEventCreate(&eventA);
      cudaEvent_t eventB; cudaEventCreate(&eventB);
     
      //--- 1) thrust
      cudaEventRecord(eventA);
      thrust::device_vector<int>    nums_expand(N*K);
      thrust::device_vector<double> crit_expand(N*K);
    
      expand(thrust::constant_iterator<int>(K),
             thrust::constant_iterator<int>(K)+N,
             nums.begin(),
             nums_expand.begin());
    
      expand(thrust::constant_iterator<int>(K),
             thrust::constant_iterator<int>(K)+N,
             crit.begin(),
             crit_expand.begin());
    
      cudaEventRecord(eventB);
      cudaEventSynchronize(eventB);
      float expandtime; cudaEventElapsedTime(&expandtime, eventA, eventB);
      std::cout << "expandtime " << expandtime << " ms\n";
    
      cudaEventRecord(eventA);
    
      thrust::transform(thrust::make_zip_iterator(vals.begin(), 
                                                  data.begin(),
                                                  thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1/K), // for N
                                                  thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1%K), // for K
                                                  nums_expand.begin(),
                                                  crit_expand.begin()),
                        thrust::make_zip_iterator(vals.end(), 
                                                  data.end(),
                                                  thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1/K) + N*K, 
                                                  thrust::make_transform_iterator(thrust::counting_iterator<int>(0), _1%K) + N*K, 
                                                  nums_expand.end(),
                                                  crit_expand.end()),
                        res.begin(),
                        myOp());
    
      cudaEventRecord(eventB);
      cudaEventSynchronize(eventB);
      float thrusttransformtime; cudaEventElapsedTime(&thrusttransformtime, eventA, eventB);
      std::cout << "thrusttransformtime " << thrusttransformtime << " ms\n";
    
      //--- 2) kernel
      thrust::device_vector<double> vals2 = h_vals;
      thrust::device_vector<double> data2 = h_data;
      thrust::device_vector<double> crit2 = h_crit;
      thrust::device_vector<int>    nums2 = h_nums;
    
      cudaEventRecord(eventA);
    
      // myOpKernel uses one thread per segment (threads with index >= N return
      // immediately), so N threads are enough -- not N*K.
      int Nthreads = 256;
      int Nblocks = (N + Nthreads - 1) / Nthreads;
      myOpKernel<<<Nblocks,Nthreads>>>(vals2.data().get(), data2.data().get(), nums2.data().get(), crit2.data().get(), N, K);
    
      cudaEventRecord(eventB);
      cudaEventSynchronize(eventB);
      if (!launchOk("myOpKernel")) return 1;
      float myOpKerneltime; cudaEventElapsedTime(&myOpKerneltime, eventA, eventB);
      std::cout << "myOpKerneltime " << myOpKerneltime << " ms\n";
    
      if(res == vals2){
        std::cout << "results from thrust and myOpKernel match\n";
      }else{
        std::cout << "results from thrust and myOpKernel do not match\n";
      }
    
      {
        //1 block per index
        thrust::device_vector<double> vals_new = h_vals;
        thrust::device_vector<double> data_new = h_data;
        thrust::device_vector<double> crit_new = h_crit;
        thrust::device_vector<int>    nums_new = h_nums;
    
        cudaEventRecord(eventA);
    
        int Nthreads = 256;
        int Nblocks = N;
        myOpKernel3<<<Nblocks,Nthreads>>>(vals_new.data().get(), data_new.data().get(), nums_new.data().get(), crit_new.data().get(), N, K);
    
        cudaEventRecord(eventB);
        cudaEventSynchronize(eventB);
        if (!launchOk("myOpKernel3")) return 1;
        float myOpKerneltime3; cudaEventElapsedTime(&myOpKerneltime3, eventA, eventB);
        std::cout << "myOpKerneltime3 " << myOpKerneltime3 << " ms\n";
    
        // Explicit check instead of assert(), which is compiled out under NDEBUG.
        if (vals_new != vals2) {
          std::cerr << "myOpKernel3 result differs from myOpKernel\n";
          return 1;
        }
      }
      {
        //1 thread per output position
        thrust::device_vector<double> vals_new = h_vals;
        thrust::device_vector<double> data_new = h_data;
        thrust::device_vector<double> crit_new = h_crit;
        thrust::device_vector<int>    nums_new = h_nums;
    
        cudaEventRecord(eventA);
        // numsPrefixSum[n] = nums[0] + ... + nums[n-1]; numsPrefixSum[N] = total.
        thrust::device_vector<int> numsPrefixSum(N+1);
        numsPrefixSum[0] = 0;
        thrust::inclusive_scan(
            nums_new.begin(),
            nums_new.end(),
            numsPrefixSum.begin() + 1
          );
        const int totalNums = numsPrefixSum.back();
        thrust::device_vector<int> indexForThread(totalNums, 0);
    
        // Mark the first element of every non-empty segment with its segment
        // index, then max-scan to fill in the remaining elements.
        thrust::scatter_if(
              thrust::make_counting_iterator(0),
              thrust::make_counting_iterator(0) + N, 
              numsPrefixSum.begin(),
              thrust::make_transform_iterator(
                  nums_new.begin(), 
                  [] __host__ __device__ (int i){return i > 0;}
              ),
              indexForThread.begin()
          );
    
        thrust::inclusive_scan(
            indexForThread.begin(), 
            indexForThread.begin() + totalNums, 
            indexForThread.begin(), 
            thrust::maximum<int>{}
          );
    
        cudaEventRecord(eventB);
        cudaEventSynchronize(eventB);
        float myOpKernel4_setuptime; cudaEventElapsedTime(&myOpKernel4_setuptime, eventA, eventB);
        std::cout << "myOpKernel4_setuptime " << myOpKernel4_setuptime << " ms\n";
    
        // A launch with zero blocks is an invalid configuration, so skip the
        // kernel entirely when no element needs processing.
        float myOpKerneltime4 = 0.0f;
        if (totalNums > 0) {
          cudaEventRecord(eventA);
    
          int Nthreads = 256;
          int Nblocks = (totalNums + Nthreads - 1) / Nthreads;
          myOpKernel4<<<Nblocks,Nthreads>>>(
            vals_new.data().get(), 
            data_new.data().get(), 
            nums_new.data().get(), 
            crit_new.data().get(),
            numsPrefixSum.data().get(),
            indexForThread.data().get(),
            totalNums,
            N, 
            K
          );
    
          cudaEventRecord(eventB);
          cudaEventSynchronize(eventB);
          if (!launchOk("myOpKernel4")) return 1;
          cudaEventElapsedTime(&myOpKerneltime4, eventA, eventB);
        }
        std::cout << "myOpKerneltime4 " << myOpKerneltime4 << " ms\n";
    
        // Explicit check instead of assert(), which is compiled out under NDEBUG.
        if (vals_new != vals2) {
          std::cerr << "myOpKernel4 result differs from myOpKernel\n";
          return 1;
        }
      }
    
      cudaEventDestroy(eventA);
      cudaEventDestroy(eventB);
    
      return 0;
    }